diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1164,23 +1164,36 @@
       DeclareOpInterfaceMethods,
       DeclareOpInterfaceMethods,
       DeclareOpInterfaceMethods,
+      TypesMatchWith<"type of 'padding' matches element type of 'source'",
+                     "source", "padding",
+                     "$_self.cast<ShapedType>().getElementType()">,
+      // TODO: This can only be used once the "broadcast" behavior is retired
+      // and the mask type becomes unambiguous.
+      // TypesMatchWith<
+      //   "type of 'mask' matches type of 'vector' with an i1 elt type",
+      //   "vector", "mask",
+      //   [{
+      //     VectorType::get($_self.cast<VectorType>().getShape(),
+      //                     IntegerType::get($_self.getContext(), 1))
+      //   }]>,
       AttrSizedOperandSegments
     ]>,
     Arguments<(ins AnyShaped:$source,
                    Variadic<Index>:$indices,
-                   AffineMapAttr:$permutation_map,
+                   OptionalAttr<AffineMapAttr>:$permutation_map,
                    AnyType:$padding,
                    Optional<VectorOf<[I1]>>:$mask,
                    OptionalAttr<ArrayAttr>:$in_bounds)>,
     Results<(outs AnyVectorOfAnyRank:$vector)> {

-  let summary = "Reads a supervector from memory into an SSA vector value.";
+  let summary = "Reads an n-D vector from memory into an SSA vector value.";

   let description = [{
-    The `vector.transfer_read` op performs a read from a slice within a
+    The `vector.transfer_read` op performs a read from a slice of a
     [MemRef](../LangRef.md#memref-type) or a Ranked
     [Tensor](../LangRef.md#tensor-type) supplied as its first operand
-    into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+    into an n-D [vector](../LangRef.md#vector-type) of the same base elemental
+    type.

     A memref/tensor operand with vector element type, must have its vector
     element type match a suffix (shape and element type) of the vector (e.g.
@@ -1190,7 +1203,7 @@
     supplied as the operands `[1 .. 1 + rank(memref/tensor))` that defines the
     starting point of the transfer (e.g. `%A[%i0, %i1, %i2]`).

-    The permutation_map [attribute](../LangRef.md#attributes) is an
+    The permutation_map [attribute](../LangRef.md#attributes) is an optional
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
     slice to match the vector shape. The permutation map may be implicit and
     omitted from parsing and printing if it is the canonical minor identity map
@@ -1217,14 +1230,14 @@
     be lowered to a simple load if all dimensions are specified to be within
     bounds and no `mask` was specified.

-    This operation is called 'read' by opposition to 'load' because the
-    super-vector granularity is generally not representable with a single
-    hardware register. A `vector.transfer_read` is thus a mid-level abstraction
-    that supports super-vectorization with non-effecting padding for full-tile
-    only operations.
+    This operation is called 'read', as opposed to 'load', because the n-D
+    vector granularity is generally not representable with a single hardware
+    register.
+    A `vector.transfer_read` is thus a mid-level abstraction that supports
+    super-vectorization with non-effecting padding for full-tile-only
+    operations.

     More precisely, let's dive deeper into the permutation_map for the following
-    MLIR:
+    MLIR snippet:

    ```mlir
    vector.transfer_read %A[%expr1, %expr2, %expr3, %expr4]
@@ -1232,59 +1245,55 @@
      memref<?x?x?x?xf32>, vector<3x4x5xf32>
    ```

-    This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3,
-    %expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice
-    is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
+    This operation always reads a slice starting at
+    `%A[%expr1, %expr2, %expr3, %expr4]`.
+    The size of the slice is 3 along d2 and 5 along d0, so the slice is:
+    `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`

-    That slice needs to be read into a `vector<3x4x5xf32>`. Since the
-    permutation map is not full rank, there must be a broadcast along vector
-    dimension `1`.
+    That slice needs to be read into a `vector<3x4x5xf32>`. Since the permutation
+    map is not full rank, there must be a broadcast along vector dimension `1`.

     A notional lowering of vector.transfer_read could generate code resembling:

    ```mlir
     // %expr1, %expr2, %expr3, %expr4 defined before this point
-    %tmp = alloc() : vector<3x4x5xf32>
-    %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+    %tmp = alloc() : memref<3x4x5xf32>
+    %view_in_tmp = "element_type_cast"(%tmp) : memref<vector<3x4x5xf32>>
     for %i = 0 to 3 {
       affine.for %j = 0 to 4 {
         affine.for %k = 0 to 5 {
           %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
-          store %tmp[%i, %j, %k] : vector<3x4x5xf32>
+          store %a, %tmp[%i, %j, %k] : memref<3x4x5xf32>
    }}}
     %c0 = arith.constant 0 : index
-    %vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+    %vec = load %view_in_tmp[] : vector<3x4x5xf32>
    ```

-    On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that
-    the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are
-    actually transferred between `%A` and `%tmp`.
-
-    Alternatively, if a notional vector broadcast operation were available, the
-    lowered code would resemble:
+    Notice that the temporary storage footprint is `3 * 5` values but `3 * 4 * 5`
+    values are actually transferred between `%A` and `%tmp`.
+
+    Alternatively, the lowered code could use a vector.broadcast and resemble:

    ```mlir
     // %expr1, %expr2, %expr3, %expr4 defined before this point
-    %tmp = alloc() : vector<3x4x5xf32>
-    %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+    %tmp = alloc() : memref<3x5xf32>
+    %view_in_tmp = "element_type_cast"(%tmp) : memref<vector<3x5xf32>>
     for %i = 0 to 3 {
       affine.for %k = 0 to 5 {
         %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
-        store %tmp[%i, 0, %k] : vector<3x4x5xf32>
+        store %a, %tmp[%i, %k] : memref<3x5xf32>
    }}
     %c0 = arith.constant 0 : index
-    %tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
-    %vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
+    %tmpvec = load %view_in_tmp[] : vector<3x5xf32>
+    %tmpvec2 = vector.broadcast %tmpvec : vector<3x5xf32> to vector<4x3x5xf32>
+    %vec = vector.transpose %tmpvec2, [1, 0, 2] : vector<4x3x5xf32> to vector<3x4x5xf32>
    ```

-    where `broadcast` broadcasts from element 0 to all others along the
-    specified dimension. This time, the temporary storage footprint is `3 * 5`
-    values which is the same amount of data as the `3 * 5` values transferred.
-    An additional `1` broadcast is required. On a GPU this broadcast could be
-    implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
-
+    This time, the temporary storage footprint is `3 * 5` values, which is the
+    same amount of data as the `3 * 5` values transferred, at the expense of an
+    additional vector.broadcast and vector.transpose.
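[Editor note: to make the now-optional `permutation_map` concrete, here is a minimal sketch of the elided and explicit spellings; with this patch the elided form no longer materializes an explicit attribute. The SSA names (`%A`, `%c0`, `%pad`) and shapes are illustrative, not taken from the patch.]

```mlir
// Elided form: the canonical minor identity map
// affine_map<(d0, d1) -> (d0, d1)> is implied.
%0 = vector.transfer_read %A[%c0, %c0], %pad
    : memref<8x16xf32>, vector<8x16xf32>

// Explicit form: a non-identity map must still be spelled out; this one
// transposes the slice while reading it.
%1 = vector.transfer_read %A[%c0, %c0], %pad
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : memref<8x16xf32>, vector<16x8xf32>
```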
+
     Syntax
     ```
     operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list
@@ -1361,12 +1370,24 @@
   ];

   let extraClassDeclaration = [{
+    AffineMap getPermutationMapOrMinorIdentity() {
+      if (getPermutationMap()) return *getPermutationMap();
+      return getTransferMinorIdentityMap(getSource().getType().cast<ShapedType>(),
+                                         getVector().getType().cast<VectorType>());
+    }
+
     // MaskableOpInterface methods.
     bool supportsPassthru() { return true; }
   }];

+  let assemblyFormat = [{
+    $source `[` $indices `]`
+    `,` $padding
+    (`,` $mask^ `:` type($mask))?
+    custom<CustomTransferAttrs>($permutation_map, $in_bounds, attr-dict)
+    `:` type($source) `,` type($vector)
+  }];
   let hasCanonicalizer = 1;
-  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -1498,6 +1519,10 @@
   ];

   let extraClassDeclaration = [{
+    AffineMap getPermutationMapOrMinorIdentity() {
+      return getPermutationMap();
+    }
+
     /// This method is added to maintain uniformity with load/store
     /// ops of other dialects.
     Value getValue() { return getVector(); }
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -111,7 +111,7 @@
       /*retTy=*/"::mlir::AffineMap",
       /*methodName=*/"permutation_map",
       /*args=*/(ins),
-      /*methodBody=*/"return $_op.getPermutationMap();"
+      /*methodBody=*/"return $_op.getPermutationMapOrMinorIdentity();"
      /*defaultImplementation=*/
    >,
    InterfaceMethod<
@@ -121,7 +121,7 @@
      /*args=*/(ins "unsigned":$idx),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
-        auto expr = $_op.getPermutationMap().getResult(idx);
+        auto expr = $_op.getPermutationMapOrMinorIdentity().getResult(idx);
        return expr.template isa<::mlir::AffineConstantExpr>() &&
               expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0;
      }]
@@ -177,7 +177,7 @@
      /*defaultImplementation=*/[{
        return $_op.getMask()
                   ? ::mlir::vector::detail::transferMaskType(
-                        $_op.getVectorType(), $_op.getPermutationMap())
+                        $_op.getVectorType(), $_op.getPermutationMapOrMinorIdentity())
                   : ::mlir::VectorType();
      }]
    >,
@@ -189,7 +189,7 @@
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/
-        "return $_op.getPermutationMap().getNumResults();"
+        "return $_op.getPermutationMapOrMinorIdentity().getNumResults();"
    >,
    InterfaceMethod<
      /*desc=*/[{ Return the number of leading shaped dimensions that do not
@@ -260,8 +260,8 @@
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
-        SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 0);
-        for (auto vecDims : llvm::zip($_op.getPermutationMap().getResults(),
+        SmallVector<int64_t> dimSizes($_op.getPermutationMapOrMinorIdentity().getNumDims(), 0);
+        for (auto vecDims : llvm::zip($_op.getPermutationMapOrMinorIdentity().getResults(),
                                      $_op.getVectorType().getShape())) {
          AffineExpr dim = std::get<0>(vecDims);
          int64_t size = std::get<1>(vecDims);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -51,7 +51,7 @@
   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
   Location loc = xferOp.getLoc();
   unsigned offsetsIdx = 0;
-  for (auto expr : xferOp.getPermutationMap().getResults()) {
+  for (auto expr : xferOp.getPermutationMapOrMinorIdentity().getResults()) {
     if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
       Value prevIdx = indices[dim.getPosition()];
       SmallVector<Value> dims(dimValues.begin(), dimValues.end());
@@ -121,7 +121,7 @@
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
-  AffineMap map = readOp.getPermutationMap();
+  AffineMap map = readOp.getPermutationMapOrMinorIdentity();
   OpBuilder b(readOp.getContext());
   AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
   AffineExpr zero = b.getAffineConstantExpr(0);
@@ -150,7 +150,7 @@
   if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
     return false;
   // TODO: Support transpose once it is added to GPU dialect ops.
-  if (!writeOp.getPermutationMap().isMinorIdentity())
+  if (!writeOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
     return false;
   return true;
 }
@@ -397,7 +397,7 @@
   AffineMap permutationMap =
       AffineMap::getPermutationMap(permU, op.getContext());
   AffineMap newMap =
-      permutationMap.compose(transferReadOp.getPermutationMap());
+      permutationMap.compose(transferReadOp.getPermutationMapOrMinorIdentity());
   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
       op, op.getType(), transferReadOp.getSource(),
       transferReadOp.getIndices(), AffineMapAttr::get(newMap),
@@ -433,7 +433,7 @@
   assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
-  AffineMap map = op.getPermutationMap();
+  AffineMap map = op.getPermutationMapOrMinorIdentity();
   // Handle broadcast by setting the stride to 0.
   if (map.getResult(0).isa<AffineConstantExpr>()) {
     assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
@@ -519,7 +519,7 @@
   FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
       *warpMatrixInfo,
-      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
+      /*transpose=*/!op.getPermutationMapOrMinorIdentity().isMinorIdentity());
   if (failed(params)) {
     return op->emitError()
            << "failed to convert vector.transfer_read to ldmatrix; this op "
@@ -541,7 +541,7 @@
                   indices);
   nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
       loc, vectorType, op.getSource(), indices,
-      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
+      !op.getPermutationMapOrMinorIdentity().isMinorIdentity(), params->numTiles);
   valueMapping[op] = newOp->getResult(0);
   return success();
 }
@@ -574,7 +574,7 @@
       builder.getZeroAttr(vectorType.getElementType()));
   Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
-  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
+  bool isTransposeLoad = !op.getPermutationMapOrMinorIdentity().isMinorIdentity();

   // If we are not transposing, then we can use vectorized loads. Otherwise, we
   // must load each element individually.
@@ -654,7 +654,7 @@
   // When we are transposing the B operand, ldmatrix will only work if we have
   // at least 8 rows to read and the width to read for the transpose is 128
   // bits.
-  if (!op.getPermutationMap().isMinorIdentity() &&
+  if (!op.getPermutationMapOrMinorIdentity().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
        vecTy.getDimSize(0) * bitWidth < 128))
     isLdMatrixCompatible = false;
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -57,7 +57,7 @@
 static Optional<int64_t> unpackedDim(OpTy xferOp) {
   // TODO: support 0-d corner case.
   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
-  auto map = xferOp.getPermutationMap();
+  auto map = xferOp.getPermutationMapOrMinorIdentity();
   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
   }
@@ -73,7 +73,7 @@
 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
   // TODO: support 0-d corner case.
   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
-  auto map = xferOp.getPermutationMap();
+  auto map = xferOp.getPermutationMapOrMinorIdentity();
   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                         b.getContext());
 }
@@ -1096,7 +1096,7 @@
 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                    SmallVector<Value, 8> &memrefIndices) {
   auto indices = xferOp.getIndices();
-  auto map = xferOp.getPermutationMap();
+  auto map = xferOp.getPermutationMapOrMinorIdentity();
   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

   memrefIndices.append(indices.begin(), indices.end());
@@ -1228,7 +1228,7 @@
     // TODO: support 0-d corner case.
     if (xferOp.getTransferRank() == 0)
       return failure();
-    auto map = xferOp.getPermutationMap();
+    auto map = xferOp.getPermutationMapOrMinorIdentity();
     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -341,7 +341,7 @@
             transferReadOp, transferReadOp.getVectorType(),
             subViewOp.getSource(), sourceIndices,
             getPermutationMapAttr(rewriter.getContext(), subViewOp,
-                                  transferReadOp.getPermutationMap()),
+                                  transferReadOp.getPermutationMapOrMinorIdentity()),
             transferReadOp.getPadding(),
             /*mask=*/Value(), transferReadOp.getInBoundsAttr());
       })
@@ -446,7 +446,7 @@
         rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
             op, op.getValue(), subViewOp.getSource(), sourceIndices,
             getPermutationMapAttr(rewriter.getContext(), subViewOp,
-                                  op.getPermutationMap()),
+                                  op.getPermutationMapOrMinorIdentity()),
             op.getInBoundsAttr());
       })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -30,6 +31,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
@@ -157,7 +159,8 @@
   return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
          !read.getMask() && defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMapOrMinorIdentity() ==
+             read.getPermutationMapOrMinorIdentity();
 }

 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
@@ -165,7 +168,8 @@
   return priorWrite.getIndices() == write.getIndices() &&
          priorWrite.getMask() == write.getMask() &&
          priorWrite.getVectorType() == write.getVectorType() &&
-         priorWrite.getPermutationMap() == write.getPermutationMap();
+         priorWrite.getPermutationMapOrMinorIdentity() ==
+             write.getPermutationMapOrMinorIdentity();
 }

 bool mlir::vector::isDisjointTransferIndices(
@@ -2976,73 +2980,46 @@
   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 }

-void TransferReadOp::print(OpAsmPrinter &p) {
-  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
-  if (getMask())
-    p << ", " << getMask();
-  printTransferAttrs(p, *this);
-  p << " : " << getShapedType() << ", " << getVectorType();
-}
+static void printCustomTransferAttrs(OpAsmPrinter &printer, Operation *op,
+                                     AffineMapAttr permutationMapAttr,
+                                     ArrayAttr inBoundsAttr,
+                                     DictionaryAttr attrs) {
+  auto xferOp = cast<TransferReadOp>(op);
+  SmallVector<StringRef> elidedAttrs;
+  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());

-ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
-  auto &builder = parser.getBuilder();
-  SMLoc typesLoc;
-  OpAsmParser::UnresolvedOperand sourceInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
-  OpAsmParser::UnresolvedOperand paddingInfo;
-  SmallVector<Type, 2> types;
-  OpAsmParser::UnresolvedOperand maskInfo;
-  // Parsing with support for paddingValue.
-  if (parser.parseOperand(sourceInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseComma() || parser.parseOperand(paddingInfo))
-    return failure();
-  ParseResult hasMask = parser.parseOptionalComma();
-  if (hasMask.succeeded()) {
-    if (parser.parseOperand(maskInfo))
-      return failure();
+  SmallVector<NamedAttribute> components{attrs.getValue().begin(),
+                                         attrs.getValue().end()};
+  // If the permutation map is absent or is the minor identity, it is elided
+  // from the dictionary.
+  if (!permutationMapAttr ||
+      permutationMapAttr.getAffineMap().isMinorIdentity()) {
+    elidedAttrs.push_back(xferOp.getPermutationMapAttrStrName());
   }
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
-    return failure();
-  if (types.size() != 2)
-    return parser.emitError(typesLoc, "requires two types");
-  auto indexType = builder.getIndexType();
-  auto shapedType = types[0].dyn_cast<ShapedType>();
-  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
-    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
-  VectorType vectorType = types[1].dyn_cast<VectorType>();
-  if (!vectorType)
-    return parser.emitError(typesLoc, "requires vector type");
-  auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
-  Attribute mapAttr = result.attributes.get(permutationAttrName);
-  if (!mapAttr) {
-    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    // Update `mapAttr` that is used later to determine mask type.
-    mapAttr = AffineMapAttr::get(permMap);
-    result.attributes.set(permutationAttrName, mapAttr);
+  // If there is an inBoundsAttr and it is all false, it is conservatively
+  // elided from the dictionary.
+  if (!inBoundsAttr || llvm::all_of(inBoundsAttr.getAsRange<BoolAttr>(),
                                    [](BoolAttr b) { return !b.getValue(); })) {
+    elidedAttrs.push_back(xferOp.getInBoundsAttrStrName());
   }
-  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexType, result.operands) ||
-      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
-                            result.operands))
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+}
+
+static ParseResult parseCustomTransferAttrs(OpAsmParser &parser,
+                                            AffineMapAttr &permutationMapAttr,
+                                            ArrayAttr &inBoundsAttr,
+                                            NamedAttrList &attrs) {
+  if (failed(parser.parseOptionalAttrDict(attrs)))
     return failure();
-  if (hasMask.succeeded()) {
-    if (shapedType.getElementType().dyn_cast<VectorType>())
-      return parser.emitError(
-          maskInfo.location, "does not support masks with vector element type");
-    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
-    // Instead of adding the mask type as an op type, compute it based on the
-    // vector type and the permutation map (to keep the type signature small).
-    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
-    if (parser.resolveOperand(maskInfo, maskType, result.operands))
-      return failure();
+  if (auto attr = attrs.get("permutation_map")) {
+    permutationMapAttr = attr.dyn_cast_or_null<AffineMapAttr>();
+    attrs.erase("permutation_map");
   }
-  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr(
-                          {1, static_cast<int32_t>(indexInfo.size()), 1,
-                           static_cast<int32_t>(hasMask.succeeded())}));
-  return parser.addTypeToList(vectorType, result.types);
+  if (auto attr = attrs.get("in_bounds")) {
+    inBoundsAttr = attr.dyn_cast_or_null<ArrayAttr>();
+    attrs.erase("in_bounds");
+  }
+  return success();
 }

 LogicalResult TransferReadOp::verify() {
@@ -3051,7 +3028,7 @@
   VectorType vectorType = getVectorType();
   VectorType maskType = getMaskType();
   auto paddingType = getPadding().getType();
-  auto permutationMap = getPermutationMap();
+  auto permutationMap = getPermutationMapOrMinorIdentity();
   auto sourceElementType = shapedType.getElementType();

   if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
@@ -3107,7 +3084,7 @@
     // TODO: Be less conservative.
     if (op.getTransferRank() == 0)
       return failure();
-    AffineMap permutationMap = op.getPermutationMap();
+    AffineMap permutationMap = op.getPermutationMapOrMinorIdentity();
     bool changed = false;
     SmallVector<bool, 4> newInBounds;
     newInBounds.reserve(op.getTransferRank());
@@ -3230,7 +3207,7 @@
       return failure();
     if (xferOp.hasOutOfBoundsDim())
       return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
+    if (!xferOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
      return failure();
     if (xferOp.getMask())
       return failure();
@@ -3338,8 +3315,10 @@
     if (!vec)
       return failure();
     SmallVector<unsigned> permutation;
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    AffineMap readMap =
+        compressUnusedDims(readOp.getPermutationMapOrMinorIdentity());
+    AffineMap writeMap =
+        compressUnusedDims(defWrite.getPermutationMapOrMinorIdentity());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
@@ -3543,8 +3522,8 @@
     if (read.getTransferRank() == 0)
       return failure();
     // For now, only accept minor identity. Future: composition is minor identity.
-    if (!read.getPermutationMap().isMinorIdentity() ||
-        !write.getPermutationMap().isMinorIdentity())
+    if (!read.getPermutationMapOrMinorIdentity().isMinorIdentity() ||
+        !write.getPermutationMapOrMinorIdentity().isMinorIdentity())
       return failure();
     // Bail on mismatching ranks.
     if (read.getTransferRank() != write.getTransferRank())
       return failure();
@@ -3578,7 +3557,8 @@
                                        vector::TransferWriteOp write) {
   return read.getSource() == write.getSource() &&
          read.getIndices() == write.getIndices() &&
-         read.getPermutationMap() == write.getPermutationMap() &&
+         read.getPermutationMapOrMinorIdentity() ==
+             write.getPermutationMapOrMinorIdentity() &&
          read.getVectorType() == write.getVectorType() && !read.getMask() &&
          !write.getMask();
 }
@@ -3732,7 +3712,7 @@
     if (!llvm::equal(xferOp.getVectorType().getShape(),
                      xferOp.getShapedType().getShape()))
       return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
+    if (!xferOp.getPermutationMapOrMinorIdentity().isIdentity())
       return failure();

     // Bail on illegal rank-reduction: we need to check that the rank-reduced
@@ -3848,16 +3828,18 @@
     assert(transferOp.getVectorType().hasStaticShape() &&
            "expected vector to have a static shape");
     ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
-    SmallVector<int64_t> resultShape = applyPermutationMap(
-        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
+    SmallVector<int64_t> resultShape =
+        applyPermutationMap(transferOp.getPermutationMapOrMinorIdentity(),
+                            transferOp.getShapedType().getShape());
     if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
       return rewriter.notifyMatchFailure(
           insertOp, "TransferWriteOp may not write the full tensor.");
     }

     // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
-    SmallVector<int64_t> newResultShape = applyPermutationMap(
-        transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
+    SmallVector<int64_t> newResultShape =
+        applyPermutationMap(transferOp.getPermutationMapOrMinorIdentity(),
+                            insertOp.getSourceType().getShape());
     SmallVector<bool> newInBounds;
     for (const auto &en : enumerate(newResultShape))
       newInBounds.push_back(en.value() == vectorShape[en.index()]);
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -56,8 +56,8 @@
       return failure();
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
         rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
-        readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
-        readOp.getInBoundsAttr());
+        AffineMapAttr::get(readOp.getPermutationMapOrMinorIdentity()),
+        readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -81,8 +81,10 @@
     // Replace the `vector.mask` operation.
     rewriter.replaceOpWithNewOp<TransferReadOp>(
         maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
-        readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
-        maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
+        readOp.getIndices(),
+        AffineMapAttr::get(readOp.getPermutationMapOrMinorIdentity()),
+        readOp.getPadding(), maskingOp.getMask(),
+        readOp.getInBounds().value_or(ArrayAttr()));
     return success();
   }
 };
@@ -103,8 +105,9 @@
     // Replace the `vector.mask` operation.
     rewriter.replaceOpWithNewOp<TransferWriteOp>(
         maskingOp.getOperation(), resultType, writeOp.getVector(),
-        writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
-        maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
+        writeOp.getSource(), writeOp.getIndices(),
+        writeOp.getPermutationMapOrMinorIdentity(), maskingOp.getMask(),
+        writeOp.getInBounds().value_or(ArrayAttr()));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -498,7 +498,7 @@
     auto newWarpOp =
         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
     rewriter.setInsertionPoint(newWriteOp);
-    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
+    AffineMap indexMap = map.compose(newWriteOp.getPermutationMapOrMinorIdentity());
     Location loc = newWriteOp.getLoc();
     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                                newWriteOp.getIndices().end());
@@ -738,7 +738,7 @@
     auto sequentialType = read.getResult().getType().cast<VectorType>();
     auto distributedType = distributedVal.getType().cast<VectorType>();
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
-    AffineMap indexMap = map.compose(read.getPermutationMap());
+    AffineMap indexMap = map.compose(read.getPermutationMapOrMinorIdentity());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -207,7 +207,7 @@
     if (newType == oldType)
       return failure();

-    AffineMap oldMap = read.getPermutationMap();
+    AffineMap oldMap = read.getPermutationMapOrMinorIdentity();
     ArrayRef<AffineExpr> newResults =
         oldMap.getResults().take_back(newType.getRank());
     AffineMap newMap =
@@ -255,7 +255,7 @@
       return failure();
     int64_t dropDim = oldType.getRank() - newType.getRank();

-    AffineMap oldMap = write.getPermutationMap();
+    AffineMap oldMap = write.getPermutationMapOrMinorIdentity();
     ArrayRef<AffineExpr> newResults =
         oldMap.getResults().take_back(newType.getRank());
     AffineMap newMap =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -310,7 +310,7 @@
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
-    if (!transferReadOp.getPermutationMap().isMinorIdentity())
+    if (!transferReadOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
      return failure();
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
@@ -353,7 +353,7 @@
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
-    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
+    if (!transferWriteOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
      return failure();
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
@@ -464,7 +464,7 @@
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
-    if (!transferReadOp.getPermutationMap().isMinorIdentity())
+    if (!transferReadOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
@@ -524,7 +524,7 @@
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
-    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
+    if (!transferWriteOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
       return failure();
     if (transferWriteOp.getMask())
       return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -62,7 +62,7 @@
       return failure();

     SmallVector<unsigned> permutation;
-    AffineMap map = op.getPermutationMap();
+    AffineMap map = op.getPermutationMapOrMinorIdentity();
     if (map.getNumResults() == 0)
       return failure();
     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
@@ -151,7 +151,7 @@
       return failure();

     SmallVector<unsigned> permutation;
-    AffineMap map = op.getPermutationMap();
+    AffineMap map = op.getPermutationMapOrMinorIdentity();
     if (map.isMinorIdentity())
       return failure();
     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
@@ -209,7 +209,7 @@
     if (op.getTransferRank() == 0)
       return failure();

-    AffineMap map = op.getPermutationMap();
+    AffineMap map = op.getPermutationMapOrMinorIdentity();
     unsigned numLeadingBroadcast = 0;
     for (auto expr : map.getResults()) {
       auto dimExpr = expr.dyn_cast<AffineDimExpr>();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2090,7 +2090,7 @@
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
     // We let the 0-d corner case pass-through as it is supported.
-    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
+    if (!read.getPermutationMapOrMinorIdentity().isMinorIdentityWithBroadcasting(
             &broadcastedDims))
       return failure();

@@ -2233,7 +2233,7 @@
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
     if ( // pass-through for the 0-d corner case.
-        !write.getPermutationMap().isMinorIdentity())
+        !write.getPermutationMapOrMinorIdentity().isMinorIdentity())
       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
         diag << "permutation map is not minor identity: " << write;
       });
@@ -2676,7 +2676,7 @@
     if (!srcType || !srcType.hasStaticShape())
       return failure();

-    if (!readOp.getPermutationMap().isMinorIdentity())
+    if (!readOp.getPermutationMapOrMinorIdentity().isMinorIdentity())
       return failure();

     auto targetType = readOp.getVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -212,7 +212,7 @@
       SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
       SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
-                               readOp.getPermutationMap(), loc, rewriter);
+                               readOp.getPermutationMapOrMinorIdentity(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
           loc, targetType, readOp.getSource(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
@@ -266,7 +266,7 @@
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
       SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
-                               writeOp.getPermutationMap(), loc, rewriter);
+                               writeOp.getPermutationMapOrMinorIdentity(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
           indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1681,7 +1681,7 @@
 func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
   %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
   %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
+  %f = vector.transfer_read %A[%base], %f7, %m : vector<5xi1> : memref<?xf32>, vector<5xf32>
   return %f: vector<5xf32>
 }

diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -11,7 +11,7 @@
 // CHECK-SAME:      %[[VAL_1:.*]]: index,
 // CHECK-SAME:      %[[VAL_2:.*]]: vector<16xi1>) -> vector<16xf32> {
 // CHECK-NOT:       vector.mask
-// CHECK:           %[[VAL_4:.*]] = vector.transfer_read {{.*}}, %[[VAL_2]] : tensor<?xf32>, vector<16xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read {{.*}}, %[[VAL_2]] : vector<16xi1>{{.*}} : tensor<?xf32>, vector<16xf32>
 // CHECK:           return %[[VAL_4]] : vector<16xf32>
 // CHECK:         }

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -68,9 +68,9 @@
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi8>>, vector<5x48xi8>
   %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : memref<?x?xvector<4x3xi8>>, vector<5x48xi8>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
-  %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
+  %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : vector<5xi1> : memref<?x?xf32>, vector<5xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?x?xf32>, vector<5x4x8xf32>
-  %9 = vector.transfer_read %arg4[%c3, %c3, %c3], %f0, %m2 {permutation_map = affine_map<(d0, d1, d2)->(d1, d0, 0)>} : memref<?x?x?xf32>, vector<5x4x8xf32>
+  %9 = vector.transfer_read %arg4[%c3, %c3, %c3], %f0, %m : vector<5xi1> {permutation_map = affine_map<(d0, d1, d2)->(d1, d0, 0)>} : memref<?x?x?xf32>, vector<5x4x8xf32>

   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -39,7 +39,7 @@
       [0, 0, 1, 1, 1, 1, 1, 0, 1],
       [1, 1, 1, 1, 1, 1, 1, 0, 1],
       [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
-  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask : vector<4x9xi1>
       {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
     memref<?x?xf32>, vector<4x9xf32>
   vector.print %f: vector<4x9xf32>
@@ -55,7 +55,7 @@
       [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0],
       [1, 1, 1, 1]]> : vector<9x4xi1>
-  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask : vector<9x4xi1>
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
    memref<?x?xf32>, vector<9x4xf32>
  vector.print %f: vector<9x4xf32>
@@ -67,7 +67,7 @@
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
-  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask : vector<9xi1>
      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
    memref<?x?xf32>, vector<4x9xf32>
  vector.print %f: vector<4x9xf32>
@@ -79,7 +79,7 @@
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[1, 0, 1, 1]> : vector<4xi1>
-  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask : vector<4xi1>
      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
    memref<?x?xf32>, vector<4x9xf32>
  vector.print %f: vector<4x9xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -52,7 +52,7 @@
    %A : memref<?x?x?x?xf32>, %o: index, %a: index, %b: index, %c: index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[0, 1]> : vector<2xi1>
-  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask : vector<2xi1>
      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>} :
    memref<?x?x?x?xf32>, vector<2x5x3xf32>
  vector.print %f: vector<2x5x3xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
@@ -20,7 +20,7 @@
 func.func @transfer_read_mask_1d(%A : memref<?xf32>, %base: index) {
   %fm42 = arith.constant -42.0: f32
   %m = arith.constant dense<[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]> : vector<13xi1>
-  %f = vector.transfer_read %A[%base], %fm42, %m : memref<?xf32>, vector<13xf32>
+  %f = vector.transfer_read %A[%base], %fm42, %m : vector<13xi1> : memref<?xf32>, vector<13xf32>
   vector.print %f: vector<13xf32>
   return
 }
@@ -37,7 +37,7 @@
 func.func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
   %fm42 = arith.constant -42.0: f32
   %m = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
-  %f = vector.transfer_read %A[%base], %fm42, %m {in_bounds = [true]}
+  %f = vector.transfer_read %A[%base], %fm42, %m : vector<4xi1> {in_bounds = [true]}
     : memref<?xf32>, vector<4xf32>
   vector.print %f: vector<4xf32>
   return
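[Editor note: the test updates above all follow the same pattern; here is a minimal before/after sketch of the new explicit mask-type syntax. The names (`%A`, `%i`, `%pad`, `%mask`) are illustrative, not taken from the patch.]

```mlir
// Before this patch, the mask type was inferred from the vector type and the
// permutation map, so it did not appear in the assembly:
%r0 = vector.transfer_read %A[%i], %pad, %mask
    : memref<?xf32>, vector<16xf32>

// After this patch, the mask type is printed and parsed explicitly, per the
// `(`,` $mask^ `:` type($mask))?` clause in the new declarative assemblyFormat:
%r1 = vector.transfer_read %A[%i], %pad, %mask : vector<16xi1>
    : memref<?xf32>, vector<16xf32>
```

This decouples the written form of the mask from the "broadcast" inference rule, which is what allows the commented-out `TypesMatchWith` constraint in VectorOps.td to be enabled later once that behavior is retired.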