diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -21,7 +21,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/VectorUnrollInterface.h" +#include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" // Pull in all enum type definitions and utility function declarations. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -18,7 +18,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/VectorUnrollInterface.td" +include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" def StandardOps_Dialect : Dialect { diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -19,7 +19,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/VectorUnrollInterface.h" +#include "mlir/Interfaces/VectorInterfaces.h" namespace mlir { class MLIRContext; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -15,7 +15,7 @@ include "mlir/Dialect/Affine/IR/AffineOpsBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include 
"mlir/Interfaces/VectorUnrollInterface.td" +include "mlir/Interfaces/VectorInterfaces.td" def Vector_Dialect : Dialect { let name = "vector"; @@ -905,34 +905,9 @@ let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)"; } -def Vector_TransferOpUtils { - code extraTransferDeclaration = [{ - static StringRef getMaskedAttrName() { return "masked"; } - static StringRef getPermutationMapAttrName() { return "permutation_map"; } - bool isMaskedDim(unsigned dim) { - return !masked() || - masked()->cast()[dim].cast().getValue(); - } - MemRefType getMemRefType() { - return memref().getType().cast(); - } - VectorType getVectorType() { - return vector().getType().cast(); - } - // Number of dimensions that participate in the permutation map. - unsigned getTransferRank() { - return permutation_map().getNumResults(); - } - // Number of leading dimensions that do not participate in the permutation - // map. - unsigned getLeadingMemRefRank() { - return getMemRefType().getRank() - permutation_map().getNumResults(); - } - }]; -} - def Vector_TransferReadOp : Vector_Op<"transfer_read", [ + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins AnyMemRef:$memref, Variadic:$indices, @@ -1090,23 +1065,12 @@ "ArrayRef maybeMasked = {}"> ]; - let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration # - [{ - /// Build the default minor identity map suitable for a vector transfer. - /// This also handles the case memref<... x vector<...>> -> vector<...> in - /// which the rank of the identity map must take the vector element type - /// into account. 
- static AffineMap getTransferMinorIdentityMap( - MemRefType memRefType, VectorType vectorType) { - return impl::getTransferMinorIdentityMap(memRefType, vectorType); - } - }]; - let hasFolder = 1; } def Vector_TransferWriteOp : Vector_Op<"transfer_write", [ + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins AnyVector:$vector, AnyMemRef:$memref, @@ -1183,18 +1147,6 @@ "Value memref, ValueRange indices, AffineMap permutationMap">, ]; - let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration # - [{ - /// Build the default minor identity map suitable for a vector transfer. - /// This also handles the case memref<... x vector<...>> -> vector<...> in - /// which the rank of the identity map must take the vector element type - /// into account. - static AffineMap getTransferMinorIdentityMap( - MemRefType memRefType, VectorType vectorType) { - return impl::getTransferMinorIdentityMap(memRefType, vectorType); - } - }]; - let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h @@ -153,6 +153,12 @@ makePermutationMap(Operation *op, ArrayRef indices, const DenseMap &loopToVectorDim); +/// Build the default minor identity map suitable for a vector transfer. This +/// also handles the case memref<... x vector<...>> -> vector<...> in which the +/// rank of the identity map must take the vector element type into account. 
+AffineMap getTransferMinorIdentityMap(MemRefType memRefType, + VectorType vectorType); + namespace matcher { /// Matches vector.transfer_read, vector.transfer_write and ops that return a diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -5,6 +5,6 @@ add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) -add_mlir_interface(VectorUnrollInterface) +add_mlir_interface(VectorInterfaces) add_mlir_interface(ViewLikeInterface) diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h rename from mlir/include/mlir/Interfaces/VectorUnrollInterface.h rename to mlir/include/mlir/Interfaces/VectorInterfaces.h --- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h @@ -1,4 +1,4 @@ -//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===// +//===- VectorInterfaces.h - Vector interfaces -----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,18 +6,18 @@ // //===----------------------------------------------------------------------===// // -// This file implements the operation interface for vector ops that can be -// unrolled. +// This file implements the operation interfaces for vector ops. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H -#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H +#ifndef MLIR_INTERFACES_VECTORINTERFACES_H +#define MLIR_INTERFACES_VECTORINTERFACES_H +#include "mlir/IR/AffineMap.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" /// Include the generated interface declarations. -#include "mlir/Interfaces/VectorUnrollInterface.h.inc" +#include "mlir/Interfaces/VectorInterfaces.h.inc" -#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H +#endif // MLIR_INTERFACES_VECTORINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -0,0 +1,194 @@ +//===- VectorInterfaces.td - Vector interfaces -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for operations on vectors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_VECTORINTERFACES +#define MLIR_INTERFACES_VECTORINTERFACES + +include "mlir/IR/OpBase.td" + +def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> { + let description = [{ + Encodes properties of an operation on vectors that can be unrolled. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the shape ratio of unrolling to the target vector shape + `targetShape`. Return `None` if the op cannot be unrolled to the target + vector shape. 
+ }], + /*retTy=*/"Optional>", + /*methodName=*/"getShapeForUnroll", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.getOperation()->getNumResults() == 1); + auto vt = $_op.getResult().getType(). + template dyn_cast(); + if (!vt) + return None; + SmallVector res(vt.getShape().begin(), vt.getShape().end()); + return res; + }] + >, + ]; +} + +def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { + let description = [{ + Encodes properties of a transfer read or write operation. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + StaticInterfaceMethod< + /*desc=*/"Return the `masked` attribute name.", + /*retTy=*/"StringRef", + /*methodName=*/"getMaskedAttrName", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ [{ return "masked"; }] + >, + StaticInterfaceMethod< + /*desc=*/"Return the `permutation_map` attribute name.", + /*retTy=*/"StringRef", + /*methodName=*/"getPermutationMapAttrName", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ [{ return "permutation_map"; }] + >, + InterfaceMethod< + /*desc=*/[{ + Return `false` when the `masked` attribute at dimension + `dim` is set to `false`. 
Return `true` otherwise.}], + /*retTy=*/"bool", + /*methodName=*/"isMaskedDim", + /*args=*/(ins "unsigned":$dim), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return !$_op.masked() || + $_op.masked()->template cast()[dim] + .template cast().getValue(); + }] + >, + InterfaceMethod< + /*desc=*/"Return the memref operand.", + /*retTy=*/"Value", + /*methodName=*/"memref", + /*args=*/(ins), + /*methodBody=*/"return $_op.memref();" + /*defaultImplementation=*/ + >, + InterfaceMethod< + /*desc=*/"Return the vector operand or result.", + /*retTy=*/"Value", + /*methodName=*/"vector", + /*args=*/(ins), + /*methodBody=*/"return $_op.vector();" + /*defaultImplementation=*/ + >, + InterfaceMethod< + /*desc=*/"Return the indices operands.", + /*retTy=*/"ValueRange", + /*methodName=*/"indices", + /*args=*/(ins), + /*methodBody=*/"return $_op.indices();" + /*defaultImplementation=*/ + >, + InterfaceMethod< + /*desc=*/"Return the permutation map.", + /*retTy=*/"AffineMap", + /*methodName=*/"permutation_map", + /*args=*/(ins), + /*methodBody=*/"return $_op.permutation_map();" + /*defaultImplementation=*/ + >, + InterfaceMethod< + /*desc=*/"Return the `masked` boolean ArrayAttr.", + /*retTy=*/"Optional", + /*methodName=*/"masked", + /*args=*/(ins), + /*methodBody=*/"return $_op.masked();" + /*defaultImplementation=*/ + >, + InterfaceMethod< + /*desc=*/"Return the MemRefType.", + /*retTy=*/"MemRefType", + /*methodName=*/"getMemRefType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ + "return $_op.memref().getType().template cast();" + >, + InterfaceMethod< + /*desc=*/"Return the VectorType.", + /*retTy=*/"VectorType", + /*methodName=*/"getVectorType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ + "return $_op.vector().getType().template cast();" + >, + InterfaceMethod< + /*desc=*/[{ Return the number of dimensions that participate in the + permutation map.}], + /*retTy=*/"unsigned", + /*methodName=*/"getTransferRank", + 
+ /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ + "return $_op.permutation_map().getNumResults();" + >, + InterfaceMethod< + /*desc=*/[{ Return the number of leading memref dimensions that do not + participate in the permutation map.}], + /*retTy=*/"unsigned", + /*methodName=*/"getLeadingMemRefRank", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ + "return $_op.getMemRefType().getRank() - $_op.getTransferRank();" + >, + InterfaceMethod< + /*desc=*/[{ + Helper function to account for the fact that `permutationMap` results and + `op.indices` sizes may not match and may not be aligned. The first + `getLeadingMemRefRank()` indices may just be indexed and not transferred + from/into the vector. + For example: + ``` + vector.transfer %0[%i, %j, %k, %c0] : + memref<?x?x?x?xf32>, vector<2x4xf32> + ``` + with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`. + Provide a zip function to coiterate on 2 running indices: `resultIdx` and + `indicesIdx` which accounts for this misalignment. + }], + /*retTy=*/"void", + /*methodName=*/"zipResultAndIndexing", + /*args=*/(ins "llvm::function_ref":$fun), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + for (int64_t resultIdx = 0, + indicesIdx = $_op.getLeadingMemRefRank(), + eResult = $_op.getTransferRank(); + resultIdx < eResult; + ++resultIdx, ++indicesIdx) + fun(resultIdx, indicesIdx); + }] + >, + ]; +} + +#endif // MLIR_INTERFACES_VECTORINTERFACES diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td deleted file mode 100644 --- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td +++ /dev/null @@ -1,46 +0,0 @@ -//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines the interface for operations on vectors that can be unrolled. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE -#define MLIR_INTERFACES_VECTORUNROLLINTERFACE - -include "mlir/IR/OpBase.td" - -def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> { - let description = [{ - Encodes properties of an operation on vectors that can be unrolled. - }]; - let cppNamespace = "::mlir"; - - let methods = [ - InterfaceMethod<[{ - Returns the shape ratio of unrolling to the target vector shape - `targetShape`. Returns `None` if the op cannot be unrolled to the target - vector shape. - }], - "Optional>", - "getShapeForUnroll", - (ins), - /*methodBody=*/[{}], - [{ - auto vt = this->getOperation()->getResult(0).getType(). - template dyn_cast(); - if (!vt) - return None; - SmallVector res(vt.getShape().begin(), vt.getShape().end()); - return res; - }] - >, - ]; -} - -#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -249,8 +249,8 @@ indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end()); indexing.append(minorOffsets.begin(), minorOffsets.end()); Value memref = xferOp.memref(); - auto map = TransferReadOp::getTransferMinorIdentityMap( - xferOp.getMemRefType(), minorVectorType); + auto map = + getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType); ArrayAttr masked; if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { OpBuilder &b = ScopedContext::getBuilderRef(); @@ -353,8 +353,8 @@ result = vector_extract(xferOp.vector(), majorIvs); else result = std_load(alloc, 
majorIvs); - auto map = TransferWriteOp::getTransferMinorIdentityMap( - xferOp.getMemRefType(), minorVectorType); + auto map = + getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType); ArrayAttr masked; if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { OpBuilder &b = ScopedContext::getBuilderRef(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -82,8 +82,8 @@ /// Return true if we can prove that the transfer operations access dijoint /// memory. -template -static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) { +static bool isDisjoint(VectorTransferOpInterface transferA, + VectorTransferOpInterface transferB) { if (transferA.memref() != transferB.memref()) return false; // For simplicity only look at transfer of same type. @@ -91,8 +91,8 @@ return false; unsigned rankOffset = transferA.getLeadingMemRefRank(); for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { - auto indexA = transferA.indices()[i].template getDefiningOp(); - auto indexB = transferB.indices()[i].template getDefiningOp(); + auto indexA = transferA.indices()[i].getDefiningOp(); + auto indexB = transferB.indices()[i].getDefiningOp(); // If any of the indices are dynamic we cannot prove anything. if (!indexA || !indexB) continue; @@ -100,15 +100,15 @@ if (i < rankOffset) { // For dimension used as index if we can prove that index are different we // know we are accessing disjoint slices. - if (indexA.getValue().template cast().getInt() != - indexB.getValue().template cast().getInt()) + if (indexA.getValue().cast().getInt() != + indexB.getValue().cast().getInt()) return true; } else { // For this dimension, we slice a part of the memref we need to make sure // the intervals accessed don't overlap. 
int64_t distance = - std::abs(indexA.getValue().template cast().getInt() - - indexB.getValue().template cast().getInt()); + std::abs(indexA.getValue().cast().getInt() - + indexB.getValue().cast().getInt()); if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) return true; } @@ -185,11 +185,17 @@ continue; if (auto transferWriteUse = dyn_cast(use.getOwner())) { - if (!isDisjoint(transferWrite, transferWriteUse)) + if (!isDisjoint( + cast(transferWrite.getOperation()), + cast( + transferWriteUse.getOperation()))) return WalkResult::advance(); } else if (auto transferReadUse = dyn_cast(use.getOwner())) { - if (!isDisjoint(transferWrite, transferReadUse)) + if (!isDisjoint( + cast(transferWrite.getOperation()), + cast( + transferReadUse.getOperation()))) return WalkResult::advance(); } else { // Unknown use, we cannot prove that it doesn't alias with the diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -15,7 +15,7 @@ MLIREDSC MLIRIR MLIRSideEffectInterfaces - MLIRVectorUnrollInterface + MLIRVectorInterfaces MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/CMakeLists.txt @@ -19,5 +19,5 @@ MLIRSCF MLIRLoopAnalysis MLIRSideEffectInterfaces - MLIRVectorUnrollInterface + MLIRVectorInterfaces ) diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1466,22 +1466,6 @@ // TransferReadOp //===----------------------------------------------------------------------===// -/// Build the default minor identity map suitable for a vector transfer. This -/// also handles the case memref<... 
x vector<...>> -> vector<...> in which the -/// rank of the identity map must take the vector element type into account. -AffineMap -mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType, - VectorType vectorType) { - int64_t elementVectorRank = 0; - VectorType elementVectorType = - memRefType.getElementType().dyn_cast(); - if (elementVectorType) - elementVectorRank += elementVectorType.getRank(); - return AffineMap::getMinorIdentityMap( - memRefType.getRank(), vectorType.getRank() - elementVectorRank, - memRefType.getContext()); -} - template static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { @@ -1600,11 +1584,10 @@ build(builder, result, vectorType, memref, indices, permMap, maybeMasked); } -template -static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) { +static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector elidedAttrs; - if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap( - op.getMemRefType(), op.getVectorType())) + if (op.permutation_map() == + getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType())) elidedAttrs.push_back(op.getPermutationMapAttrName()); bool elideMasked = true; if (auto maybeMasked = op.masked()) { @@ -1623,7 +1606,7 @@ static void print(OpAsmPrinter &p, TransferReadOp op) { p << op.getOperationName() << " " << op.memref() << "[" << op.indices() << "], " << op.padding(); - printTransferAttrs(p, op); + printTransferAttrs(p, cast(op.getOperation())); p << " : " << op.getMemRefType() << ", " << op.getVectorType(); } @@ -1653,8 +1636,7 @@ auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { - auto permMap = - TransferReadOp::getTransferMinorIdentityMap(memRefType, vectorType); + auto permMap = getTransferMinorIdentityMap(memRefType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } return 
failure( @@ -1733,6 +1715,7 @@ int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx); int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); + return cstOp.getValue() + vectorSize <= memrefSize; } @@ -1744,23 +1727,11 @@ bool changed = false; SmallVector isMasked; isMasked.reserve(op.getTransferRank()); - // `permutationMap` results and `op.indices` sizes may not match and may not - // be aligned. The first `indicesIdx` may just be indexed and not transferred - // from/into the vector. - // For example: - // vector.transfer %0[%i, %j, %k, %c0] : memref, vector<2x4xf32> - // with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`. - // The `permutationMap` results and `op.indices` are however aligned when - // iterating in reverse until we exhaust `permutationMap` results. - // As a consequence we iterate with 2 running indices: `resultIdx` and - // `indicesIdx`, until `resultIdx` reaches 0. - for (int64_t resultIdx = permutationMap.getNumResults() - 1, - indicesIdx = op.indices().size() - 1; - resultIdx >= 0; --resultIdx, --indicesIdx) { + op.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { // Already marked unmasked, nothing to see here. if (!op.isMaskedDim(resultIdx)) { isMasked.push_back(false); - continue; + return; } // Currently masked, check whether we can statically determine it is // inBounds. @@ -1768,12 +1739,11 @@ isMasked.push_back(!inBounds); // We commit the pattern if it is "more inbounds". changed |= inBounds; - } + }); if (!changed) return failure(); // OpBuilder is only used as a helper to build an I64ArrayAttr. 
OpBuilder b(op.getContext()); - std::reverse(isMasked.begin(), isMasked.end()); op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked)); return success(); } @@ -1842,8 +1812,7 @@ auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { - auto permMap = - TransferWriteOp::getTransferMinorIdentityMap(memRefType, vectorType); + auto permMap = getTransferMinorIdentityMap(memRefType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } return failure( @@ -1855,7 +1824,7 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) { p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "[" << op.indices() << "]"; - printTransferAttrs(p, op); + printTransferAttrs(p, cast(op.getOperation())); p << " : " << op.getVectorType() << ", " << op.getMemRefType(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -30,7 +30,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" -#include "mlir/Interfaces/VectorUnrollInterface.h" +#include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -243,6 +243,18 @@ return ::makePermutationMap(indices, enclosingLoopToVectorDim); } +AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType, + VectorType vectorType) { + int64_t elementVectorRank = 0; + VectorType elementVectorType = + memRefType.getElementType().dyn_cast(); + if (elementVectorType) + elementVectorRank += elementVectorType.getRank(); + return 
AffineMap::getMinorIdentityMap( + memRefType.getRank(), vectorType.getRank() - elementVectorRank, + memRefType.getContext()); +} + bool matcher::operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType) { // First, extract the vector type and distinguish between: @@ -257,11 +269,8 @@ bool mustDivide = false; (void)mustDivide; VectorType superVectorType; - if (auto read = dyn_cast(op)) { - superVectorType = read.getVectorType(); - mustDivide = true; - } else if (auto write = dyn_cast(op)) { - superVectorType = write.getVectorType(); + if (auto transfer = dyn_cast(op)) { + superVectorType = transfer.getVectorType(); mustDivide = true; } else if (op.getNumResults() == 0) { if (!isa(op)) { diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -6,7 +6,7 @@ InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp - VectorUnrollInterface.cpp + VectorInterfaces.cpp ViewLikeInterface.cpp ) @@ -33,6 +33,6 @@ add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(SideEffectInterfaces) -add_mlir_interface_library(VectorUnrollInterface) +add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) diff --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp rename from mlir/lib/Interfaces/VectorUnrollInterface.cpp rename to mlir/lib/Interfaces/VectorInterfaces.cpp --- a/mlir/lib/Interfaces/VectorUnrollInterface.cpp +++ b/mlir/lib/Interfaces/VectorInterfaces.cpp @@ -1,4 +1,4 @@ -//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===// +//===- VectorInterfaces.cpp - Unrollable vector operations -*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Interfaces/VectorUnrollInterface.h" +#include "mlir/Interfaces/VectorInterfaces.h" using namespace mlir; @@ -15,4 +15,4 @@ //===----------------------------------------------------------------------===// /// Include the definitions of the VectorUntoll interfaces. -#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc" +#include "mlir/Interfaces/VectorInterfaces.cpp.inc"