diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -44,21 +44,6 @@
 void populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns);
 
-/// Collect a set of vector-to-vector transformation patterns.
-void populateVectorToVectorTransformationPatterns(RewritePatternSet &patterns);
-
-/// Collect a set of patterns to split transfer read/write ops.
-///
-/// These patterns unrolls transfer read/write ops if the vector consumers/
-/// producers are extract/insert slices op. Transfer ops can map to hardware
-/// load/store functionalities, where the vector size matters for bandwith
-/// considerations. So these patterns should be collected separately, instead
-/// of being generic canonicalization patterns. Also one can let the
-/// `ignoreFilter` to return true to fail matching for fine-grained control.
-void populateSplitVectorTransferPatterns(
-    RewritePatternSet &patterns,
-    std::function<bool(Operation *)> ignoreFilter = nullptr);
-
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions
@@ -74,16 +59,6 @@
 /// vectors and there are more chances to share extract/insert ops.
 void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
 
-/// Collect a set of vector slices transformation patterns:
-///    ExtractSlicesOpLowering, InsertSlicesOpLowering
-/// Useful for clients that want to express all vector "slices"
-/// ops in terms of more elementary vector "slice" ops. If all
-/// "produced" tuple values are "consumed" (the most common
-/// use for "slices" ops), this lowering removes all tuple related
-/// operations as well (through DCE and folding). If tuple values
-/// "leak" coming in, however, some tuple related ops will remain.
-void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
-
 /// Collect a set of transfer read/write lowering patterns.
 ///
 /// These patterns lower transfer ops to simpler ops like `vector.load`,
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -523,63 +523,6 @@
   let hasFolder = 1;
 }
 
-def Vector_ExtractSlicesOp :
-  Vector_Op<"extract_slices", [NoSideEffect]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$sizes,
-               I64ArrayAttr:$strides)>,
-    Results<(outs TupleOf<[AnyVector]>)> {
-  let summary = "vector extract slices operation";
-  let description = [{
-    Takes an N-d vector and returns a tuple of vector slices of 'vector',
-    based on 'sizes' and 'strides' parameters.
-
-    The arguments 'sizes' and 'strides' represent a specification for
-    generating the unrolling of 'vector' shape, which has all slices of shape
-    'sizes' except for slices at dimension boundaries when 'vector' dimension
-    sizes are not a multiple of 'sizes'.
-
-    Each slice is returned at the tuple element index corresponding to the
-    linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
-    Currently, only unit strides are supported.
-
-    Example:
-
-    ```mlir
-    %0 = vector.transfer_read ...: vector<4x2xf32>
-
-    %1 = vector.extract_slices %0, [2, 2], [1, 1]
-      : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
-
-    // Example with partial slices at dimension boundaries.
-    %2 = vector.transfer_read ...: vector<4x3xf32>
-
-    %3 = vector.extract_slices %2, [2, 2], [1, 1]
-      : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
-                                   vector<2x2xf32>, vector<2x1xf32>>
-    ```
-  }];
-  let builders = [
-    OpBuilder<(ins "TupleType":$tupleType, "Value":$vector,
-      "ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides)>
-  ];
-  let extraClassDeclaration = [{
-    VectorType getSourceVectorType() {
-      return vector().getType().cast<VectorType>();
-    }
-    TupleType getResultTupleType() {
-      return getResult().getType().cast<TupleType>();
-    }
-    void getSizes(SmallVectorImpl<int64_t> &results);
-    void getStrides(SmallVectorImpl<int64_t> &results);
-    static StringRef getSizesAttrName() { return "sizes"; }
-    static StringRef getStridesAttrName() { return "strides"; }
-  }];
-  let assemblyFormat = [{
-    $vector `,` $sizes `,` $strides attr-dict `:` type($vector) `into`
-      type(results)
-  }];
-}
-
 def Vector_ExtractMapOp :
   Vector_Op<"extract_map", [NoSideEffect]>,
     Arguments<(ins AnyVector:$vector, Variadic<Index>:$ids)>,
@@ -652,7 +595,7 @@
   Op<Vector_Dialect, "fma", [
        NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
-     ]>,
+     ] # ElementwiseMappable.traits>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
     Results<(outs AnyVector:$result)> {
   let summary = "vector fused multiply-add";
@@ -769,63 +712,6 @@
   let hasFolder = 1;
 }
 
-def Vector_InsertSlicesOp :
-  Vector_Op<"insert_slices", [NoSideEffect]>,
-    Arguments<(ins TupleOf<[AnyVector]>:$vectors, I64ArrayAttr:$sizes,
-               I64ArrayAttr:$strides)>,
-    Results<(outs AnyVector)> {
-  let summary = "vector insert slices operation";
-  let description = [{
-    Takes a tuple of vector slices and inserts them into the vector result
-    according to the 'sizes' and 'strides' parameters.
-
-    The arguments 'sizes' and 'strides' represent a specification for
-    generating the unrolling of 'vector' shape, which has all slices of shape
-    'sizes' except for slices at dimension boundaries when 'vector' dimension
-    sizes are not a multiple of 'sizes'.
-
-    Each slice in 'vectors' is at the tuple element index corresponding to the
-    linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
-    Currently, only unit strides are supported.
-
-    Example:
-
-    ```mlir
-    %0 = vector.extract_slices %0, [2, 2], [1, 1]
-      : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
-
-    %1 = vector.insert_slices %0, [2, 2], [1, 1]
-      : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
-
-    // Example with partial slices at dimension boundaries.
-    %3 = vector.extract_slices %2, [2, 2], [1, 1]
-      : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
-                                   vector<2x2xf32>, vector<2x1xf32>>
-
-    %4 = vector.insert_slices %3, [2, 2], [1, 1]
-      : tuple<vector<2x2xf32>, vector<2x1xf32>,
-              vector<2x2xf32>, vector<2x1xf32>> into vector<4x3xf32>
-    ```
-  }];
-
-  let extraClassDeclaration = [{
-    TupleType getSourceTupleType() {
-      return vectors().getType().cast<TupleType>();
-    }
-    VectorType getResultVectorType() {
-      return getResult().getType().cast<VectorType>();
-    }
-    void getSizes(SmallVectorImpl<int64_t> &results);
-    void getStrides(SmallVectorImpl<int64_t> &results);
-    static StringRef getSizesAttrName() { return "sizes"; }
-    static StringRef getStridesAttrName() { return "strides"; }
-  }];
-  let assemblyFormat = [{
-    $vectors `,` $sizes `,` $strides attr-dict `:` type($vectors) `into`
-      type(results)
-  }];
-}
-
 def Vector_InsertMapOp :
   Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>,
     Arguments<(ins AnyVector:$vector, AnyVector:$dest, Variadic<Index>:$ids)>,
@@ -2011,8 +1897,8 @@
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [NoSideEffect]>,
-    Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
-    Results<(outs AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$result)> {
+    Arguments<(ins AnyVector:$source)>,
+    Results<(outs AnyVector:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
     The shape_cast operation casts between an n-D source vector shape and
@@ -2028,9 +1914,6 @@
     order (i.e. 0 <= j < k). The product of all source dimension sizes and
     all result dimension sizes must match.
 
-    If the source/result types are a tuple of vectors, the casting operation
-    described above is applied to each source/result tuple element pair.
-
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
 
@@ -2048,10 +1931,6 @@
     // Example casting to a higher vector rank.
     %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
 
-    // Example casting a tuple of vectors of same rank, where tuple elements
-    // may have different shapes.
-    %5 = vector.shape_cast %4 : tuple<vector<3x4x2xf32>, vector<3x3x2xf32>> to
-                                tuple<vector<12x2xf32>, vector<9x2xf32>>
     ```
   }];
   let extraClassDeclaration = [{
@@ -2305,43 +2184,6 @@
   let hasFolder = 1;
 }
 
-def Vector_TupleGetOp :
-  Vector_Op<"tuple_get", [NoSideEffect]>,
-    Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,
-    Results<(outs AnyVector)> {
-  let summary = "vector tuple get operation";
-  let description = [{
-    Returns the tuple element of 'vectors' at 'index'.
-
-    Note that this operation is used during the vector op unrolling
-    transformation and should be removed before lowering to lower-level
-    dialects.
-
-    Example:
-
-    ```mlir
-    %4 = vector.tuple %0, %1, %2, %3
-      : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>>
-
-    %5 = vector.tuple_get %4, 1
-      : tuple<vector<2x2xf32>, vector<2x1xf32>,
-              vector<2x2xf32>, vector<2x1xf32>>
-    ```
-  }];
-
-  let extraClassDeclaration = [{
-    VectorType getResultVectorType() {
-      return getResult().getType().cast<VectorType>();
-    }
-    int64_t getIndex() {
-      auto index = (*this)->getAttrOfType<IntegerAttr>("index");
-      return index.getValue().getSExtValue();
-    }
-    static StringRef getIndexAttrName() { return "index"; }
-  }];
-  let hasFolder = 1;
-}
-
 def Vector_PrintOp :
   Vector_Op<"print", []>,
     Arguments<(ins AnyType:$source)> {
   let summary = "print operation (for testing and debugging)";
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -33,61 +33,6 @@
 
 namespace vector {
 
-/// Entry point for unrolling declarative pattern rewrites.
-/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
-///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
-///   assumed the unrolling factors divide the vector sizes.
-///   2. a fakeFork cast op is inserted that takes the operand and returns
-///   `numUnrolledInstances` results of type `unrolledVectorType`.
-///   3. the original op is cloned `numUnrolledInstances` times, once for each
-///   result of the fakeFork cast op.
-///   4. a fakeJoin cast op takes all these results and merges them into a
-///   single aggregate vector result whose size matches the original
-///   non-unrolled op operand types.
-///
-/// Example:
-///
-///    opA(operand0, operand1)  // numUnrolledInstances = 3
-///
-///            operand0                   operand1
-///               |                          |
-///             fork                       fork
-///        <----------gather all fork ops --------->
-///              /|\                        /|\
-///          f00 f01 f02                f10 f11 f12
-///        <---------- clone op 3 times --------->
-///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-///                 \            |            /
-///      <-------------------- join ------------------------->
-///
-/// Other local patterns then kick in iteratively (including DCE) and compose
-/// until all the fakeFork and fakeJoin ops are removed.
-///
-/// This will be extended in the future to support more advanced use cases than
-/// simple pointwise ops.
-SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
-                                                 Operation *op,
-                                                 ArrayRef<int64_t> targetShape);
-
-/// Unroll a transfer_write op. Break up the vector source into a tuple of
-/// vectors matching the given shape. Then store each element with its own
-/// transfer_write. If the transfer_write takes a tensor source, return the
-/// unrolled Value in result.
-///
-/// Example:
-/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
-/// ->
-/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
-///   vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
-/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
-/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
-/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
-/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
-LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
-                                    ArrayRef<int64_t> targetShape,
-                                    SmallVector<Value, 1> &result);
-
 /// Options that control the vector unrolling.
 struct UnrollVectorOptions {
   using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
@@ -118,53 +63,39 @@
     return *this;
   }
 };
-/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
-/// declaratively.
-struct UnrollVectorPattern : public RewritePattern {
-  using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
-  UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        options(options) {}
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    if (options.filterConstraint && failed(options.filterConstraint(op)))
-      return failure();
-    if (!options.nativeShape) {
-      return op->emitError("vector unrolling expects the native shape or native"
-                           "shape call back function to be set");
-    }
-    auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
-    if (!unrollableVectorOp)
-      return failure();
-    auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
-    if (!maybeUnrollShape)
-      return failure();
-    Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
-    if (!targetShape)
-      return op->emitError("failed to get target shape for vector unroll");
-    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
-    if (!maybeShapeRatio ||
-        llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
-      return failure();
-    if (isa<TransferWriteOp>(op)) {
-      SmallVector<Value, 1> result;
-      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape, result)))
-        return failure();
-      rewriter.replaceOp(op, result);
-      return success();
-    }
-    if (op->getNumResults() != 1)
-      return failure();
-    auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
-    if (resultVector.size() != 1)
-      return failure();
-    rewriter.replaceOp(op, resultVector.front());
-    return success();
-  }
-
-private:
-  UnrollVectorOptions options;
-};
+/// Collect a set of patterns to unroll vector operations to smaller shapes.
+/// The `options` structure controls which operations are unrolled and the
+/// target shape.
+/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
+///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
+///   assumed the unrolling factors divide the vector sizes.
+///   2. ExtractStridedSlice ops are created to break up the vector operands.
+///   3. the original op is cloned `numUnrolledInstances` times, once for each
+///   result.
+///   4. InsertStridedSlice ops are inserted to reassemble the slices into the
+///   original vector shape.
+///
+/// Example:
+///
+///    opA(operand0, operand1)  // numUnrolledInstances = 3
+///
+///            operand0                   operand1
+///               |                          |
+///             fork                       fork
+///        <----------gather all fork ops --------->
+///              /|\                        /|\
+///          f00 f01 f02                f10 f11 f12
+///        <---------- clone op 3 times --------->
+///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+///                 \            |            /
+///      <-------------------- join ------------------------->
+///
+/// Other local patterns then kick in iteratively (including DCE) and compose
+/// to combine the ExtractStridedSlice/InsertStridedSlice ops.
+void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options);
 
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
 /// masking) fastpath and a slowpath.
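(Editorial illustration, not part of the patch: with the new `populateVectorUnrollPatterns` entry point and a native shape of `[2, 2]`, an elementwise op such as `vector.fma` over `vector<4x4xf32>` is rewritten roughly as sketched below. Value names are schematic, and `%zero` stands for the zero-constant destination vector the patterns create to reassemble the unrolled results.)

```mlir
// One of the four unrolled instances; the same sequence is repeated for
// offsets [0, 2], [2, 0] and [2, 2].
%a0 = vector.extract_strided_slice %a
    {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
  : vector<4x4xf32> to vector<2x2xf32>
%b0 = vector.extract_strided_slice %b
    {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
  : vector<4x4xf32> to vector<2x2xf32>
%c0 = vector.extract_strided_slice %c
    {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
  : vector<4x4xf32> to vector<2x2xf32>
%f0 = vector.fma %a0, %b0, %c0 : vector<2x2xf32>
%r0 = vector.insert_strided_slice %f0, %zero
    {offsets = [0, 0], strides = [1, 1]}
  : vector<2x2xf32> into vector<4x4xf32>
```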
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -40,12 +40,6 @@
 /// Return the number of elements of basis, `0` if empty.
 int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
 
-/// Given a shape with sizes greater than 0 along all dimensions,
-/// return the distance, in number of elements, between a slice in a dimension
-/// and the next slice in the same dimension.
-///   e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1]
-SmallVector<int64_t, 8> computeStrides(ArrayRef<int64_t> shape);
-
 /// Given the shape and sizes of a vector, returns the corresponding
 /// strides for each dimension.
 /// TODO: needs better doc of how it is used.
@@ -67,12 +61,6 @@
 computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
                                             ArrayRef<int64_t> vectorOffsets);
 
-/// Given the shape, sizes, and element-space offsets of a vector, returns
-/// the slize sizes for each dimension.
-SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
-                                          ArrayRef<int64_t> sizes,
-                                          ArrayRef<int64_t> elementOffsets);
-
 /// Computes and returns the multi-dimensional ratio of `superShape` to
 /// `subShape`. This is calculated by performing a traversal from minor to major
 /// dimensions (i.e. in reverse shape order). If integral division is not
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,6 @@
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
-    populateVectorSlicesLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns);
    populateVectorTransposeLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1192,84 +1192,12 @@
   results.add(context);
 }
 
-//===----------------------------------------------------------------------===//
-// ExtractSlicesOp
-//===----------------------------------------------------------------------===//
-
-void ExtractSlicesOp::build(OpBuilder &builder, OperationState &result,
-                            TupleType tupleType, Value vector,
-                            ArrayRef<int64_t> sizes,
-                            ArrayRef<int64_t> strides) {
-  result.addOperands(vector);
-  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(tupleType);
-  result.addAttribute(getSizesAttrName(), sizesAttr);
-  result.addAttribute(getStridesAttrName(), stridesAttr);
-}
-
-static LogicalResult
-isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
-                                 TupleType tupleType, ArrayRef<int64_t> sizes,
-                                 ArrayRef<int64_t> strides) {
-  // Check for non-unit strides.
-  // TODO: Support non-1 strides.
-  if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
-    return op->emitError("requires unit strides");
-  // Check that 'vectorType' rank matches rank of tuple element vectors.
-  unsigned rank = vectorType.getRank();
-  auto is_vector_type_of_rank = [&](Type t) {
-    return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
-  };
-  if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
-    return op->emitError("requires vector tuple elements of rank ") << rank;
-  // Check that 'sizes' and 'strides' are of size == 'rank'.
-  if (sizes.size() != rank || strides.size() != rank)
-    return op->emitError("requires sizes and strides of rank ") << rank;
-
-  // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
-  // and verify that the same matches the corresponding tuple element 'i'.
-  auto shape = vectorType.getShape();
-  auto sliceStrides = computeStrides(shape, sizes);
-  for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
-    auto vectorOffsets = delinearize(sliceStrides, i);
-    auto elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
-    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
-    // Create slice VectorType type.
-    auto sliceVectorType =
-        VectorType::get(sliceSizes, vectorType.getElementType());
-    // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
-    if (sliceVectorType != tupleType.getType(i))
-      return op->emitError("invalid tuple element type ") << sliceVectorType;
-  }
-  return success();
-}
-
-static LogicalResult verify(ExtractSlicesOp op) {
-  SmallVector<int64_t, 4> sizes;
-  op.getSizes(sizes);
-  SmallVector<int64_t, 4> strides;
-  op.getStrides(strides);
-  return isValidExtractOrInsertSlicesType(
-      op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
-      sizes, strides);
-}
-
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                        SmallVectorImpl<int64_t> &results) {
   for (auto attr : arrayAttr)
     results.push_back(attr.cast<IntegerAttr>().getInt());
 }
 
-void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(sizes(), results);
-}
-
-void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(strides(), results);
-}
-
 //===----------------------------------------------------------------------===//
 // ExtractMapOp
 //===----------------------------------------------------------------------===//
@@ -1620,28 +1548,6 @@
   return {};
 }
 
-//===----------------------------------------------------------------------===//
-// InsertSlicesOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(InsertSlicesOp op) {
-  SmallVector<int64_t, 4> sizes;
-  op.getSizes(sizes);
-  SmallVector<int64_t, 4> strides;
-  op.getStrides(strides);
-  return isValidExtractOrInsertSlicesType(
-      op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
-      sizes, strides);
-}
-
-void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(sizes(), results);
-}
-
-void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(strides(), results);
-}
-
 //===----------------------------------------------------------------------===//
 // InsertMapOp
 //===----------------------------------------------------------------------===//
@@ -3441,23 +3347,6 @@
   if (sourceVectorType && resultVectorType)
     return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
 
-  // Check if source/result are "tuple of vectors" type.
-  auto sourceTupleType = op.source().getType().dyn_cast_or_null<TupleType>();
-  auto resultTupleType = op.result().getType().dyn_cast_or_null<TupleType>();
-  if (!sourceTupleType || !resultTupleType)
-    return op.emitOpError("source/result must be of same type");
-
-  // Check that source/result tuple sizes are the same.
-  if (sourceTupleType.size() != resultTupleType.size())
-    return op.emitOpError("source/result tuples must be the same size");
-
-  // Check each source/result tuple element pair.
-  for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i)
-    if (failed(verifyVectorShapeCast(
-            op, sourceTupleType.getType(i).cast<VectorType>(),
-            resultTupleType.getType(i).cast<VectorType>())))
-      return failure();
-
   return success();
 }
 
@@ -3755,58 +3644,6 @@
   populateFromInt64AttrArray(transp(), results);
 }
 
-//===----------------------------------------------------------------------===//
-// TupleGetOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseTupleGetOp(OpAsmParser &parser,
-                                   OperationState &result) {
-  OpAsmParser::OperandType operandInfo;
-  IntegerAttr indexAttr;
-  StringRef indexAttrName = TupleGetOp::getIndexAttrName();
-  Type indexType = parser.getBuilder().getIndexType();
-  TupleType tupleType;
-  if (parser.parseOperand(operandInfo) || parser.parseComma() ||
-      parser.parseAttribute(indexAttr, indexType, indexAttrName,
-                            result.attributes) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(tupleType) ||
-      parser.resolveOperand(operandInfo, tupleType, result.operands))
-    return failure();
-  if (indexAttr.getInt() < 0 ||
-      indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
-    return failure();
-  parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
-  return success();
-}
-
-static void print(OpAsmPrinter &p, TupleGetOp op) {
-  p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
-  p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
-  p << " : " << op.getOperand().getType();
-}
-
-static LogicalResult verify(TupleGetOp op) {
-  auto tupleType = op.getOperand().getType().cast<TupleType>();
-  if (op.getIndex() < 0 ||
-      op.getIndex() >= static_cast<int64_t>(tupleType.size()))
-    return op.emitOpError("tuple get index out of range");
-  return success();
-}
-
-OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
-  // Rewrite:
-  //    %t = vector.tuple .., %e_i, ..
-  //    %x = vector.tuple_get %t, i
-  // into:
-  //    %t = vector.tuple .., %e_i, ..  // one less use
-  //    %x = %e_i
-  if (auto tupleOp = getOperand().getDefiningOp<TupleOp>())
-    return tupleOp.getOperand(getIndex());
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -38,6 +38,7 @@
 #include "mlir/Interfaces/VectorInterfaces.h"
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -172,809 +173,349 @@
   return builder.createOperation(res);
 }
 
-// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
-// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
-static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
-                              ArrayRef<int64_t> inputElements,
-                              SmallVectorImpl<int64_t> &resultElements) {
-  assert(indexMap.size() == resultElements.size());
-  assert(inputElements.size() >= resultElements.size());
-  for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
-    auto it = indexMap.find(i);
-    if (it != indexMap.end())
-      resultElements[it->second] = inputElements[i];
-  }
-}
-
-// Returns a tuple type with vector element types for each resulting slice
-// of 'vectorType' unrolled by 'sizes' and 'strides'.
-// TODO: Move this to a utility function and share it with
-// Extract/InsertSlicesOp verification.
-static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
-                                                   ArrayRef<int64_t> sizes,
-                                                   ArrayRef<int64_t> strides,
-                                                   OpBuilder &builder) {
-  assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
-  assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
-  assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
-
-  // Compute shape ratio of 'shape' and 'sizes'.
-  auto shape = vectorType.getShape();
-  auto maybeDimSliceCounts = shapeRatio(shape, sizes);
-  assert(maybeDimSliceCounts.hasValue());
-  auto sliceDimCounts = *maybeDimSliceCounts;
-
-  // Compute strides w.r.t number of slices in each dimension.
-  auto sliceStrides = computeStrides(sliceDimCounts);
-  int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
-  SmallVector<Type, 4> vectorTypes(sliceCount);
-  for (unsigned i = 0; i < sliceCount; ++i) {
-    auto vectorOffsets = delinearize(sliceStrides, i);
-    auto elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
-    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
-    // Create Vector type and add to 'vectorTypes[i]'.
-    vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
-  }
-  return builder.getTupleType(vectorTypes);
-}
-
-// UnrolledVectorState aggregates per-operand/result vector state required for
-// unrolling.
-struct UnrolledVectorState {
-  SmallVector<int64_t, 4> unrolledShape;
-  SmallVector<int64_t, 4> unrollFactors;
-  SmallVector<int64_t, 8> basis;
-  int64_t numInstances;
-  Value slicesTuple;
-};
-
-// Populates 'state' with unrolled shape, unroll factors, basis and
-// num unrolled instances for 'vectorType'.
-static void initUnrolledVectorState(VectorType vectorType, Value initValue,
-                                    const DenseMap<int64_t, int64_t> &indexMap,
-                                    ArrayRef<int64_t> targetShape,
-                                    UnrolledVectorState &state,
-                                    OpBuilder &builder) {
-  // Compute unrolled shape of 'vectorType'.
-  state.unrolledShape.resize(vectorType.getRank());
-  getMappedElements(indexMap, targetShape, state.unrolledShape);
-  // Compute unroll factors for unrolled shape.
-  auto maybeUnrollFactors =
-      shapeRatio(vectorType.getShape(), state.unrolledShape);
-  assert(maybeUnrollFactors.hasValue());
-  state.unrollFactors = *maybeUnrollFactors;
-  // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
-  state.basis = computeStrides(state.unrollFactors);
-  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
-  state.slicesTuple = nullptr;
-  if (initValue != nullptr) {
-    // Create ExtractSlicesOp.
-    SmallVector<int64_t, 4> sizes(state.unrolledShape);
-    SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
-    auto tupleType =
-        generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
-    state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
-        initValue.getLoc(), tupleType, initValue, sizes, strides);
-  }
-}
-
-// Computes and returns the linear index of the unrolled vector at
-// 'vectorOffsets' within the vector represented by 'state'.
-static int64_t
-getUnrolledVectorLinearIndex(UnrolledVectorState &state,
-                             ArrayRef<int64_t> vectorOffsets,
-                             DenseMap<int64_t, int64_t> &indexMap) {
-  // Compute vector offsets.
-  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
-  getMappedElements(indexMap, vectorOffsets, sliceOffsets);
-  // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
-  return linearize(sliceOffsets, state.basis);
-}
-
-// Returns an unrolled vector at 'vectorOffsets' within the vector
-// represented by 'state'. The vector is created from a slice of 'initValue'
-// if not present in 'cache'.
-static Value getOrCreateUnrolledVectorSlice(
-    Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
-    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
-    Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
-  // Compute slice offsets.
-  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
-  getMappedElements(indexMap, offsets, sliceOffsets);
-  // TODO: Support non-1 strides.
-  SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
-  // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
-  int64_t sliceLinearIndex =
-      getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
-  assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
-  auto valueSlice = cache[sliceLinearIndex];
-  if (valueSlice == nullptr) {
-    // Return tuple element at 'sliceLinearIndex'.
-    auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
-    auto initValueType = initValue.getType().cast<VectorType>();
-    auto vectorType =
-        VectorType::get(state.unrolledShape, initValueType.getElementType());
-    // Initialize 'cache' with slice from 'initValue'.
-    valueSlice = builder.create<vector::TupleGetOp>(
-        loc, vectorType, state.slicesTuple, tupleIndex);
-    // Store value back to 'cache'.
-    cache[sliceLinearIndex] = valueSlice;
-  }
-  return valueSlice;
-}
-
-// VectorState aggregates per-operand/result vector state required for
-// creating slices of vector operands, and clones of the operation being
-// unrolled.
-struct VectorState {
-  // The type of this vector.
-  VectorType type;
-  // Map from iteration space index to vector dimension index.
-  DenseMap<int64_t, int64_t> indexMap;
-  // Index of this value in operation's operand list (-1 if not an operand).
-  int64_t operandIndex = -1;
-  // Accumulator iterator flag.
-  bool isAcc = false;
-};
-
-//
-//  unrollSingleResultStructuredOp
-//
-// Returns a value representing the result of structured operation 'op'
-// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
-// A list of VectorState objects must be specified in 'vectors', where
-// each VectorState in the list represents a vector operand or vector result
-// (if the operation does not have an accumulator operand).
-// The VectorState at index 'resultIndex' in the list must be the state
-// associated with the operations single result (i.e. either its accumulator
-// operand or vector result value).
-//
-// Example:
-//
-//  // Before unrolling
-//
-//   operand0       operand1       operand2
-//       \              |              /
-//        -------------------- opA --------------------
-//
-//  // After unrolling by 2
-//
-//   operand0          operand1          operand2
-//   /       \         /       \         /       \
-// slice00  slice01  slice10  slice11  slice20  slice21
-//   \         |        |        |        /        |
-//    -------------------- opA0 --------------------       |
-//             |        |                 |                |
-//              \       |                /                 |
-//               -------------------- opA1 -------------------
-//                         |                       |
-//                          \                     /
-//                            insertslice
-//                                 |
-
-// TODO: Add the following canonicalization/simplification patterns:
-// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
-//    InsertStridedSlice operand to StridedSlice.
-// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
-//    if there are duplicate identical StridedSlice ops from SourceOp, and
-//    rewrites itself to use the first duplicate. This transformation should
-//    cause users of identifical StridedSlice ops to reuse the same StridedSlice
-//    operation, and leave the duplicate StridedSlice ops with no users
-//    (removable with DCE).
-
-// TODO: Generalize this to support structured ops beyond
-// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
-static Value unrollSingleResultStructuredOp(Operation *op,
-                                            ArrayRef<int64_t> iterationBounds,
-                                            std::vector<VectorState> &vectors,
-                                            unsigned resultIndex,
-                                            ArrayRef<int64_t> targetShape,
-                                            OpBuilder &builder) {
-  auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
-  if (!shapedType || !shapedType.hasStaticShape())
-    assert(false && "Expected a statically shaped result type");
-
-  // Compute unroll factors for 'iterationBounds' based on 'targetShape'
-  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
-  if (!maybeUnrollFactors.hasValue())
-    assert(false && "Failed to compute unroll factors for target shape");
-  auto unrollFactors = *maybeUnrollFactors;
-
-  // Compute unrolled vector state for each vector in 'vectors'.
-  unsigned numVectors = vectors.size();
-  SmallVector<UnrolledVectorState, 4> unrolledVectorState(numVectors);
-  for (unsigned i = 0; i < numVectors; ++i) {
-    int64_t operandIndex = vectors[i].operandIndex;
-    auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
-    initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
-                            targetShape, unrolledVectorState[i], builder);
-  }
-  // Compute number of total unrolled instances.
-  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
-  auto sliceStrides = computeStrides(unrollFactors);
-
-  auto &resultValueState = unrolledVectorState[resultIndex];
-  auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
-                                            shapedType.getElementType());
-
-  // Initialize caches for intermediate vector results.
-  std::vector<SmallVector<Value, 4>> caches(numVectors);
-  for (unsigned i = 0; i < numVectors; ++i)
-    caches[i].resize(unrolledVectorState[i].numInstances);
-
-  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
-  for (unsigned i = 0; i < numUnrolledInstances; ++i) {
-    auto vectorOffsets = delinearize(sliceStrides, i);
-    auto elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
-    // Get cached slice (or create slice) for each operand at 'offsets'.
-    SmallVector<Value, 4> operands;
-    operands.resize(op->getNumOperands());
-    for (unsigned i = 0; i < numVectors; ++i) {
-      int64_t operandIndex = vectors[i].operandIndex;
-      if (operandIndex < 0)
-        continue; // Output
-      auto operand = op->getOperand(operandIndex);
-      operands[operandIndex] = getOrCreateUnrolledVectorSlice(
-          op->getLoc(), unrolledVectorState[i], vectorOffsets, elementOffsets,
-          vectors[i].indexMap, operand, caches[i], builder);
-    }
-    // Create op on sliced vector arguments.
-    auto resultVector =
-        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
-                                    unrolledResultType)
-            ->getResult(0);
-
-    // Compute linear result index.
-    int64_t linearIndex = getUnrolledVectorLinearIndex(
-        resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
-    // Update result cache at 'linearIndex'.
-    caches[resultIndex][linearIndex] = resultVector;
-  }
-
-  // Create TupleOp of unrolled result vectors.
-  SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
-  SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
-  for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
-    vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
-    vectorTupleValues[i] = caches[resultIndex][i];
-  }
-  TupleType tupleType = builder.getTupleType(vectorTupleTypes);
-  Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
-                                                  vectorTupleValues);
-
-  // Create InsertSlicesOp(Tuple(result_vectors)).
-  auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
-  SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
-  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
-
-  Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
-      op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
-      builder.getI64ArrayAttr(strides));
-  return insertSlicesOp;
-}
-
-static void getVectorContractionOpUnrollState(
-    vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
-    std::vector<VectorState> &vectors, unsigned &resultIndex) {
-  // Get map from iteration space index to lhs/rhs/result shape index.
-  std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
-  contractionOp.getIterationIndexMap(iterationIndexMapList);
-  unsigned numIterators = iterationIndexMapList.size();
-  vectors.resize(numIterators);
-  unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
-  for (unsigned i = 0; i < numIterators; ++i) {
-    vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
-    vectors[i].indexMap = iterationIndexMapList[i];
-    vectors[i].operandIndex = i;
-    vectors[i].isAcc = i == accOperandIndex ? true : false;
-  }
-
-  if (llvm::size(contractionOp.masks()) == 2) {
-    // Add vectors for lhs/rhs vector mask arguments. Masks have the
-    // same vector shape lhs/rhs args, so copy their index maps.
-    vectors.push_back({contractionOp.getLHSVectorMaskType(),
-                       vectors[0].indexMap, accOperandIndex + 1, false});
-    vectors.push_back({contractionOp.getRHSVectorMaskType(),
-                       vectors[1].indexMap, accOperandIndex + 2, false});
-  }
-  // TODO: Use linalg style 'args_in'/'args_out' to partition
-  // 'vectors' instead of 'resultIndex'.
-  resultIndex = accOperandIndex;
+/// Return the target shape for unrolling the given `op`. Return llvm::None
+/// if the op shouldn't be or cannot be unrolled.
+static Optional<SmallVector<int64_t, 4>>
+getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
+  if (options.filterConstraint && failed(options.filterConstraint(op)))
+    return llvm::None;
+  assert(options.nativeShape &&
+         "vector unrolling expects the native shape or native "
+         "shape call back function to be set");
+  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
+  if (!unrollableVectorOp)
+    return llvm::None;
+  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+  if (!maybeUnrollShape)
+    return llvm::None;
+  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+  if (!targetShape)
+    return llvm::None;
+  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
+  if (!maybeShapeRatio ||
+      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
+    return llvm::None;
+  return targetShape;
 }
 
-static void getVectorElementwiseOpUnrollState(Operation *op,
-                                              ArrayRef<int64_t> targetShape,
-                                              std::vector<VectorState> &vectors,
-                                              unsigned &resultIndex) {
-  // Verify that operation and operands all have the same vector shape.
-  auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
-  assert(resultType && "Expected op with vector result type");
-  auto resultShape = resultType.getShape();
-  // Verify that all operands have the same vector type as result.
-  assert(llvm::all_of(op->getOperandTypes(), [=](Type type) {
-    return type.cast<VectorType>().getShape() == resultShape;
-  }));
-
-  // Create trivial elementwise identity index map based on 'resultShape'.
-  DenseMap<int64_t, int64_t> indexMap;
-  indexMap.reserve(resultShape.size());
-  for (unsigned i = 0; i < resultShape.size(); ++i)
-    indexMap[i] = i;
-
-  // Create VectorState each operand and single result.
-  unsigned numVectors = op->getNumOperands() + op->getNumResults();
-  vectors.resize(numVectors);
-  for (auto it : llvm::enumerate(op->getOperandTypes()))
-    vectors[it.index()] = {it.value().cast<VectorType>(), indexMap,
-                           static_cast<int64_t>(it.index()), false};
-  vectors[numVectors - 1] = {resultType, indexMap, -1, false};
-  resultIndex = numVectors - 1;
+/// During unrolling from `originalShape` to `targetShape` return the offset
+/// for the slice `index`.
+static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
+                                               ArrayRef<int64_t> targetShape,
+                                               int64_t index) {
+  SmallVector<int64_t, 4> dstSliceStrides =
+      computeStrides(originalShape, targetShape);
+  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
+  SmallVector<int64_t, 4> elementOffsets =
+      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
+  return elementOffsets;
 }
 
-/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
-/// calls 'fn' with linear index and indices for each slice.
-static void
-generateTransferOpSlices(Type shapedElementType, VectorType vectorType,
-                         TupleType tupleType, ArrayRef<int64_t> sizes,
-                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
-                         AffineMap permutationMap, OpBuilder &builder,
-                         function_ref<void(unsigned, ArrayRef<Value>)> fn) {
-  // Compute strides w.r.t. to slice counts in each dimension.
-  auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
-  assert(maybeDimSliceCounts.hasValue());
-  auto sliceDimCounts = *maybeDimSliceCounts;
-  auto sliceStrides = computeStrides(sliceDimCounts);
-
-  int64_t numSlices = tupleType.size();
-  // Compute 'indexOffset' at which to update 'indices', which is equal
-  // to the memref rank (indices.size) minus the effective 'vectorRank'.
-  // The effective 'vectorRank', is equal to the rank of the vector type
-  // minus the rank of the memref vector element type (if it has one).
-  //
-  // For example:
-  //
-  //   Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
-  //   transfer_read/write ops which read/write vectors of type
-  //   'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
-  //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
-  //
-  if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>())
-    assert(vectorType.getRank() >= sourceVectorElementType.getRank());
+/// Compute the indices of the slice `index` for a transfer op.
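+/// For example (illustrative): when unrolling vector<4x4xf32> with
+/// `targetShape` [2, 2] and an identity permutation map, slice index 3 maps
+/// to element offsets [2, 2], so source indices [%i, %j] become
+/// [%i + 2, %j + 2] (materialized via affine.apply below).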
+static SmallVector<Value, 4>
+sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
+                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
+                     AffineMap permutationMap, Location loc,
+                     OpBuilder &builder) {
+  MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
       return constExpr.getValue() == 0;
     return false;
   };
-  auto *ctx = builder.getContext();
-  for (unsigned i = 0; i < numSlices; ++i) {
-    auto vectorOffsets = delinearize(sliceStrides, i);
-    auto elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
-    // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
-    SmallVector<Value, 4> sliceIndices(indices.begin(), indices.end());
-    for (auto dim : llvm::enumerate(permutationMap.getResults())) {
-      if (isBroadcast(dim.value()))
-        continue;
-      unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
-      auto expr = getAffineDimExpr(0, ctx) +
-                  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
-      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-      sliceIndices[pos] = builder.create<AffineApplyOp>(
-          indices[pos].getLoc(), map, ArrayRef<Value>(indices[pos]));
-    }
-    // Call 'fn' to generate slice 'i' at 'sliceIndices'.
-    fn(i, sliceIndices);
-  }
-}
-
-/// Unroll transfer_read ops to the given shape and create an aggregate with all
-/// the chunks.
-static Value unrollTransferReadOp(vector::TransferReadOp readOp,
-                                  ArrayRef<int64_t> targetShape,
-                                  OpBuilder &builder) {
-  if (readOp.mask())
-    return nullptr;
-  auto sourceVectorType = readOp.getVectorType();
-  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
-
-  Location loc = readOp.getLoc();
-  auto shapedElementType =
-      readOp.source().getType().cast<ShapedType>().getElementType();
-  auto tupleType = generateExtractSlicesOpResultType(
-      sourceVectorType, targetShape, strides, builder);
-  int64_t numSlices = tupleType.size();
-
-  SmallVector<Value, 4> vectorTupleValues(numSlices);
-  SmallVector<Value, 4> indices(readOp.indices().begin(),
-                                readOp.indices().end());
-  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
-    // Get VectorType for slice 'i'.
-    auto sliceVectorType = tupleType.getType(index);
-    // Create split TransferReadOp for 'sliceUser'.
-    // `in_bounds` attribute propagates conservatively: if the coarse op didn't
-    // need out-of-bounds masking, the fine op doesn't either.
-    vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
-        loc, sliceVectorType, readOp.source(), sliceIndices,
-        readOp.permutation_map(), readOp.padding(),
-        readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
-  };
-  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
-                           targetShape, strides, indices,
-                           readOp.permutation_map(), builder, createSlice);
-
-  // Create tuple of splice transfer read operations.
-  Value tupleOp =
-      builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
-  // Replace 'readOp' with result 'insertSlicesResult'.
-  Value newVec = builder.create<vector::InsertSlicesOp>(
-      loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
-      builder.getI64ArrayAttr(strides));
-  return newVec;
-}
-
-// Entry point for unrolling declarative pattern rewrite for transfer_write op.
-LogicalResult
-mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
-                                    ArrayRef<int64_t> targetShape,
-                                    SmallVector<Value, 1> &result) {
-  auto writeOp = cast<vector::TransferWriteOp>(op);
-  if (writeOp.mask())
-    return failure();
-  VectorType sourceVectorType = writeOp.getVectorType();
-  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
-  TupleType tupleType = generateExtractSlicesOpResultType(
-      sourceVectorType, targetShape, strides, builder);
-  Location loc = writeOp.getLoc();
-  Value tuple = builder.create<vector::ExtractSlicesOp>(
-      loc, tupleType, writeOp.vector(), targetShape, strides);
-  auto shapedElementType =
-      writeOp.source().getType().cast<ShapedType>().getElementType();
-  SmallVector<Value, 4> indices(writeOp.indices().begin(),
-                                writeOp.indices().end());
-  // If the TransferWrite returns a tensor, keep track of the last tensor
-  // created.
-  Value resultTensor;
-  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
-    auto element = builder.create<vector::TupleGetOp>(
-        loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
-    Operation *write = builder.create<vector::TransferWriteOp>(
-        loc, element.getResult(),
-        resultTensor ? resultTensor : writeOp.source(), sliceIndices,
-        writeOp.permutation_map(),
-        writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
-    if (!write->getResults().empty())
-      resultTensor = write->getResult(0);
-  };
-  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
-                           targetShape, strides, indices,
-                           writeOp.permutation_map(), builder, createSlice);
-  if (resultTensor)
-    result.push_back(resultTensor);
-  return success();
-}
-
-// Entry point for unrolling declarative pattern rewrites.
-SmallVector<Value, 1>
-mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
-                                         ArrayRef<int64_t> targetShape) {
-  assert(op->getNumResults() == 1 && "Expected single result operation");
-
-  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
-  SmallVector<int64_t, 4> iterationBounds;
-  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
-  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
-  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
-
-  std::vector<VectorState> vectors;
-  unsigned resultIndex;
-
-  if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
-    return SmallVector<Value, 1>{
-        unrollTransferReadOp(readOp, targetShape, builder)};
-
-  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
-    // Populate state for vector ContractionOp.
-    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
-                                      resultIndex);
-  } else {
-    // Populate state for vector elementwise op.
-    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
+  SmallVector<int64_t, 4> elementOffsets =
+      getVectorOffset(originalShape, targetShape, index);
+  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
+  SmallVector<Value, 4> slicedIndices(indices.begin(), indices.end());
+  for (auto dim : llvm::enumerate(permutationMap.getResults())) {
+    if (isBroadcast(dim.value()))
+      continue;
+    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
+    auto expr = getAffineDimExpr(0, ctx) +
+                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
+    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
   }
-
-  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
-  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
-      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
+  return slicedIndices;
 }
 
 namespace {
-// Splits a TransferReadOp into smaller TransferReadOps based on slicing
-// scheme of its unique ExtractSlicesOp users.
-class SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
-public:
-  SplitTransferReadOp(MLIRContext *context,
-                      std::function<bool(Operation *)> ignoreFilter = nullptr,
-                      PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {}
-
+struct UnrollTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  UnrollTransferReadPattern(MLIRContext *context,
+                            const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
+        options(options) {}
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (ignoreFilter && ignoreFilter(readOp))
-      return failure();
     if (readOp.mask())
       return failure();
-
-    // Return unless there is only one user, and it is an ExtractSlicesOp.
-    Value readResult = readOp.getResult();
-    if (!readResult.hasOneUse())
-      return failure();
-
-    auto extractSlicesOp =
-        dyn_cast<vector::ExtractSlicesOp>(readResult.use_begin()->getOwner());
-    if (!extractSlicesOp)
-      return failure();
-
-    // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
-    SmallVector<int64_t, 4> sizes;
-    extractSlicesOp.getSizes(sizes);
-    SmallVector<int64_t, 4> strides;
-    extractSlicesOp.getStrides(strides);
-    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
-
-    Value newVec = unrollTransferReadOp(readOp, sizes, rewriter);
-    if (!newVec)
-      return failure();
-    rewriter.replaceOp(readOp, newVec);
+    auto targetShape = getTargetShape(options, readOp);
+    if (!targetShape)
+      return failure();
+    auto sourceVectorType = readOp.getVectorType();
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    Location loc = readOp.getLoc();
+    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute the number of unrolled instances.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    // Prepare the result vector.
+    Value result = rewriter.create<ConstantOp>(
+        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
+    auto targetType =
+        VectorType::get(*targetShape, sourceVectorType.getElementType());
+    SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
+                                          readOp.indices().end());
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<Value, 4> indices =
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+                               readOp.permutation_map(), loc, rewriter);
+      auto slicedRead = rewriter.create<vector::TransferReadOp>(
+          loc, targetType, readOp.source(), indices, readOp.permutation_map(),
+          readOp.padding(),
+          readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
+
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, slicedRead, result, elementOffsets, strides);
+    }
+    rewriter.replaceOp(readOp, result);
     return success();
   }
 
 private:
-  std::function<bool(Operation *)> ignoreFilter;
+  vector::UnrollVectorOptions options;
 };
 
-// Splits a TransferWriteOp into smaller TransferWriteOps for each source.
-class SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
-public:
-  SplitTransferWriteOp(MLIRContext *context,
-                       std::function<bool(Operation *)> ignoreFilter = nullptr,
-                       PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {}
-
+struct UnrollTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  UnrollTransferWritePattern(MLIRContext *context,
+                             const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
+        options(options) {}
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    if (ignoreFilter && ignoreFilter(writeOp))
-      return failure();
-
     if (writeOp.mask())
       return failure();
-
-    // Fail to match unless this is writing a vector resulting from an
-    // InsertSlicesOp.
-    auto insertSlicesOp =
-        writeOp.vector().getDefiningOp<vector::InsertSlicesOp>();
-    if (!insertSlicesOp)
-      return failure();
-
-    // Get the TupleOp operand of the InsertSlicesOp.
-    auto tupleOp = insertSlicesOp.vectors().getDefiningOp<vector::TupleOp>();
-    if (!tupleOp)
+    auto targetShape = getTargetShape(options, writeOp);
+    if (!targetShape)
       return failure();
-
-    // Get 'sizes' and 'strides' parameters from the InsertSlicesOp user.
-    auto sourceTupleType = insertSlicesOp.getSourceTupleType();
-    auto resultVectorType = insertSlicesOp.getResultVectorType();
-    SmallVector<int64_t, 4> sizes;
-    insertSlicesOp.getSizes(sizes);
-    SmallVector<int64_t, 4> strides;
-    insertSlicesOp.getStrides(strides);
-
+    auto sourceVectorType = writeOp.getVectorType();
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
-    auto shapedElementType =
-        writeOp.source().getType().cast<ShapedType>().getElementType();
-    auto indices = llvm::to_vector<4>(writeOp.indices());
+    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute the number of unrolled instances.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
+                                          writeOp.indices().end());
     Value resultTensor;
-    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
-      // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
-      // 'in_bounds' attribute propagates conservatively: if the coarse op
-      // didn't need out-of-bounds masking, the fine op doesn't either.
-      Operation *write = rewriter.create<vector::TransferWriteOp>(
-          loc, tupleOp.getOperand(index),
-          resultTensor ? resultTensor : writeOp.source(), sliceIndices,
-          writeOp.permutation_map(),
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, writeOp.vector(), elementOffsets, *targetShape, strides);
+
+      SmallVector<Value, 4> indices =
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+                               writeOp.permutation_map(), loc, rewriter);
+      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
+          indices, writeOp.permutation_map(),
           writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
-      if (!write->getResults().empty())
-        resultTensor = write->getResult(0);
-    };
-    generateTransferOpSlices(shapedElementType, resultVectorType,
-                             sourceTupleType, sizes, strides, indices,
-                             writeOp.permutation_map(), rewriter, createSlice);
-
+      // For the tensor case update the destination for the next transfer
+      // write.
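+      // (A transfer_write on a memref has no results, so `resultTensor`
+      // stays null there and the original source keeps being used.)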
+      if (!slicedWrite->getResults().empty())
+        resultTensor = slicedWrite->getResult(0);
+    }
     if (resultTensor)
-      rewriter.replaceOp(writeOp, ArrayRef<Value>(resultTensor));
+      rewriter.replaceOp(writeOp, resultTensor);
     else
       rewriter.eraseOp(writeOp);
     return success();
   }
 
 private:
-  std::function<bool(Operation *)> ignoreFilter;
+  vector::UnrollVectorOptions options;
 };
 
-/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
-/// on vector types.
-struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+template <typename T>
+SmallVector<T> permute(AffineMap map, llvm::ArrayRef<T> source) {
+  SmallVector<T> result;
+  result.reserve(map.getNumResults());
+  for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
+    unsigned dim = map.getDimPosition(i);
+    result.push_back(source[dim]);
+  }
+  return result;
+}
 
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+struct UnrollContractionPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+  struct OffsetMapInfo {
+    static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
+
+    static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
+
+    static unsigned getHashValue(const SmallVector<int64_t> &v) {
+      return static_cast<unsigned>(
+          llvm::hash_combine_range(v.begin(), v.end()));
+    }
+
+    static bool isEqual(const SmallVector<int64_t> &lhs,
+                        const SmallVector<int64_t> &rhs) {
+      return lhs == rhs;
+    }
+  };
+  UnrollContractionPattern(MLIRContext *context,
+                           const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    // Check if 'shapeCastOp' has tuple source/result type.
-    auto sourceTupleType =
-        shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
-    auto resultTupleType =
-        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
-    if (!sourceTupleType || !resultTupleType)
+    auto targetShape = getTargetShape(options, contractOp);
+    if (!targetShape)
       return failure();
-    assert(sourceTupleType.size() == resultTupleType.size());
-
-    // Create single-vector ShapeCastOp for each source tuple element.
-    Location loc = shapeCastOp.getLoc();
-    SmallVector<Value, 4> resultElements;
-    resultElements.reserve(resultTupleType.size());
-    for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
-      auto sourceElement = rewriter.create<vector::TupleGetOp>(
-          loc, sourceTupleType.getType(i), shapeCastOp.source(),
-          rewriter.getI64IntegerAttr(i));
-      resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
-          loc, resultTupleType.getType(i), sourceElement));
-    }
+    auto dstVecType = contractOp.getResultType().cast<VectorType>();
+    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
 
-    // Replace 'shapeCastOp' with tuple of 'resultElements'.
-    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
-                                                 resultElements);
+    // Compute the number of unrolled instances.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    Location loc = contractOp.getLoc();
+    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
+    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
+    llvm::MapVector<
+        SmallVector<int64_t>, Value,
+        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
+        accCache;
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
+
+      // Helper to compute the new shape of each operand and extract the slice.
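+      // E.g. (illustrative): for an lhs indexing map (d0, d1, d2) -> (d0, d2)
+      // and a target shape [M, N, K], the sliced lhs operand has shape [M, K].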
+      // Helper to compute the new shape of each operand and extract the
+      // slice.
+      auto extractOperand = [&](unsigned index, Value operand,
+                                AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape =
+            permute(permutationMap, ArrayRef<int64_t>(*targetShape));
+        SmallVector<int64_t, 4> operandStrides(operandOffsets.size(), 1);
+        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract the new lhs operand.
+      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
+      SmallVector<int64_t> lhsOffsets =
+          permute(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffsets);
+      // If there is a mask associated with the lhs, extract it as well.
+      if (slicesOperands.size() > 3)
+        extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffsets);
+
+      // Extract the new rhs operand.
+      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
+      SmallVector<int64_t> rhsOffsets =
+          permute(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffsets);
+      // If there is a mask associated with the rhs, extract it as well.
+      if (slicesOperands.size() > 4)
+        extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffsets);
+
+      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
+      SmallVector<int64_t> accOffsets =
+          permute(accPermutationMap, ArrayRef<int64_t>(offsets));
+      // If a version of the accumulator has already been computed, use it;
+      // otherwise extract the first version from the original operand.
+      auto accIt = accCache.find(accOffsets);
+      if (accIt != accCache.end())
+        slicesOperands[2] = accIt->second;
+      else
+        extractOperand(2, contractOp.acc(), accPermutationMap, accOffsets);
+
+      SmallVector<int64_t> dstShape =
+          permute(dstAffineMap, ArrayRef<int64_t>(*targetShape));
+      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
+      Operation *newOp = cloneOpWithOperandsAndTypes(
+          rewriter, loc, contractOp, slicesOperands, targetType);
+
+      SmallVector<int64_t> dstOffsets =
+          permute(dstAffineMap, ArrayRef<int64_t>(offsets));
+      // Save the accumulated value until all the loops are unrolled, since
+      // the reduction loop keeps updating the accumulator.
+      accCache[dstOffsets] = newOp->getResult(0);
+    }
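+    // Note: 'accCache' now holds one partial result per distinct destination
+    // offset, i.e. the product of the unrolling ratios of the parallel
+    // dimensions (for instance, four vector<2x2xf32> tiles when a 4x4x4
+    // contraction is unrolled by [2, 2, 2]).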
+    // Assemble back the accumulator into a single vector.
+    Value result = rewriter.create<ConstantOp>(
+        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
+    for (const auto &it : accCache) {
+      SmallVector<int64_t, 4> dstStrides(it.first.size(), 1);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, it.second, result, it.first, dstStrides);
+    }
+    rewriter.replaceOp(contractOp, result);
     return success();
   }
-};
-/// Returns the producer Value of the same type as 'consumerValue', by tracking
-/// the tuple index and offsets of the consumer vector value through the
-/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
-/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
-/// structured, and so the tuple index and offsets can be mapped from result to
-/// input, while visiting each operation in the chain.
-/// Returns nullptr on failure.
-static Value getProducerValue(Value consumerValue) {
-  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
-  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
-  int64_t tupleIndex = -1;
-  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
-  auto *op = consumerValue.getDefiningOp();
-  while (op != nullptr) {
-    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
-      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
-
-      // Update 'tupleIndex' and next defining 'op' to visit.
-      tupleIndex = tupleGetOp.getIndex();
-      op = tupleGetOp.vectors().getDefiningOp();
-    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
-      assert(tupleIndex >= 0);
-
-      // Compute slice strides for 'extractSlicesOp'.
-      SmallVector<int64_t, 4> sizes;
-      extractSlicesOp.getSizes(sizes);
-      auto sliceStrides = computeStrides(
-          extractSlicesOp.getSourceVectorType().getShape(), sizes);
-
-      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
-      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
-      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
-      auto elementOffsets =
-          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
-
-      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
-      // to the 'extractSlicesOp' input vector type.
-      assert(offsets.size() == elementOffsets.size());
-      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
-        offsets[i] += elementOffsets[i];
-
-      // Clear 'tupleIndex' and update next defining 'op' to visit.
-      tupleIndex = -1;
-      op = extractSlicesOp.vector().getDefiningOp();
-    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
-      assert(tupleIndex == -1);
-
-      // Compute slice strides for 'insertSlicesOp'.
-      SmallVector<int64_t, 4> sizes;
-      insertSlicesOp.getSizes(sizes);
-      auto sliceStrides = computeStrides(
-          insertSlicesOp.getResultVectorType().getShape(), sizes);
-
-      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
-      // of 'insertSlicesOp' result vector type at 'offsets'.
-      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
-      assert(offsets.size() == sizes.size());
-      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
-        vectorOffsets[i] = offsets[i] / sizes[i];
-
-      // Compute the source tuple element index.
-      tupleIndex = linearize(vectorOffsets, sliceStrides);
-
-      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
-      // relative to input tuple element vector type at 'tupleIndex'.
-      auto elementOffsets =
-          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
-      assert(offsets.size() == elementOffsets.size());
-      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
-        offsets[i] -= elementOffsets[i];
-        assert(offsets[i] >= 0);
-      }
+private:
+  vector::UnrollVectorOptions options;
+};
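+// Illustrative sketch of the rewrite performed by UnrollContractionPattern
+// (shapes chosen for exposition, not taken from this patch): a vector.contract
+// producing vector<4x4xf32>, unrolled with native shape [2, 2, 2], becomes
+// eight contractions of the form
+//   %s = vector.contract ... %lhs_slice, %rhs_slice, %acc_slice
+//       : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// where the slices that share the same [m, n] offsets are chained through the
+// accumulator, as exercised by @contraction4x4_ijk in vector-transforms.mlir.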
-      // Update next defining 'op' to visit.
-      op = insertSlicesOp.vectors().getDefiningOp();
-    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
-      assert(tupleIndex >= 0);
-
-      // Return tuple element 'value' at 'tupleIndex' if it matches type.
-      auto value = tupleOp.getOperand(tupleIndex);
-      if (value.getType() == consumerVectorType)
-        return value;
-
-      // Update 'tupleIndex' and next defining 'op' to visit.
-      tupleIndex = -1;
-      op = value.getDefiningOp();
-    } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
-      if (shapeCastOp.source().getType().isa<TupleType>())
-        return nullptr;
-      assert(tupleIndex == -1);
-      auto sourceVectorType = shapeCastOp.getSourceVectorType();
-      auto sourceVectorShape = sourceVectorType.getShape();
-      unsigned sourceVectorRank = sourceVectorType.getRank();
-      auto resultVectorType = shapeCastOp.getResultVectorType();
-      auto resultVectorShape = resultVectorType.getShape();
-      unsigned resultVectorRank = resultVectorType.getRank();
-
-      int i = sourceVectorRank - 1;
-      int j = resultVectorRank - 1;
-
-      // Check that source/result vector shape prefixes match while updating
-      // 'newOffsets'.
-      SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
-      for (auto it : llvm::zip(llvm::reverse(sourceVectorShape),
-                               llvm::reverse(resultVectorShape))) {
-        if (std::get<0>(it) != std::get<1>(it))
-          return nullptr;
-        newOffsets[i--] = offsets[j--];
+struct UnrollElementwisePattern : public RewritePattern {
+  UnrollElementwisePattern(MLIRContext *context,
+                           const vector::UnrollVectorOptions &options)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+        options(options) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+    auto targetShape = getTargetShape(options, op);
+    if (!targetShape)
+      return failure();
+    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
+    SmallVector<int64_t, 4> originalSize =
+        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    Location loc = op->getLoc();
+    // Prepare the result vector.
+    Value result = rewriter.create<ConstantOp>(
+        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    VectorType newVecType =
+        VectorType::get(*targetShape, dstVecType.getElementType());
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<Value, 4> extractOperands;
+      for (OpOperand &operand : op->getOpOperands()) {
+        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+        if (!vecType) {
+          extractOperands.push_back(operand.get());
+          continue;
+        }
+        extractOperands.push_back(
+            rewriter.create<vector::ExtractStridedSliceOp>(
+                loc, operand.get(), offsets, *targetShape, strides));
       }
-
-      // Check that remaining prefix of source/result vector shapes are all 1s.
-      // Currently we only support producer/consumer tracking through trivial
-      // shape cast ops. Examples:
-      //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
-      //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
-      assert(i == -1 || j == -1);
-      if (i >= 0 &&
-          !std::all_of(sourceVectorShape.begin(),
-                       sourceVectorShape.begin() + i,
-                       [](int64_t v) { return v == 1; }))
-        return nullptr;
-      if (j >= 0 &&
-          !std::all_of(resultVectorShape.begin(),
-                       resultVectorShape.begin() + j,
-                       [](int64_t v) { return v == 1; }))
-        return nullptr;
-
-      offsets.swap(newOffsets);
-      op = shapeCastOp.source().getDefiningOp();
-    } else {
-      // Check if 'op' produces a Value with the same type as 'consumerValue'.
-      if (op->getNumResults() == 1 &&
-          op->getResult(0).getType() == consumerVectorType)
-        return op->getResult(0);
-      return nullptr;
+      Operation *newOp = cloneOpWithOperandsAndTypes(
+          rewriter, loc, op, extractOperands, newVecType);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
     }
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  return nullptr;
-}
+
+private:
+  vector::UnrollVectorOptions options;
+};

 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
@@ -997,12 +538,6 @@
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
-    // Check if we can replace 'shapeCastOp' result with its producer.
-    if (auto producer = getProducerValue(shapeCastOp.getResult())) {
-      rewriter.replaceOp(shapeCastOp, producer);
-      return success();
-    }
-
-    // Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType = shapeCastOp.source().getType().dyn_cast_or_null(); @@ -1030,119 +565,6 @@ } }; -// Patter rewrite which forward tuple elements to their users. -// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer))))) -// -> User(Producer) -struct TupleGetFolderOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp, - PatternRewriter &rewriter) const override { - if (auto producer = getProducerValue(tupleGetOp.getResult())) { - rewriter.replaceOp(tupleGetOp, producer); - return success(); - } - return failure(); - } -}; - -/// Progressive lowering of ExtractSlicesOp to tuple of ExtractStridedSliceOp. -/// One: -/// %x = vector.extract_slices %0 -/// is replaced by: -/// %a = vector.strided_slice %0 -/// %b = vector.strided_slice %0 -/// .. -/// %x = vector.tuple %a, %b, .. -class ExtractSlicesOpLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ExtractSlicesOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - VectorType vectorType = op.getSourceVectorType(); - auto shape = vectorType.getShape(); - - SmallVector sizes; - op.getSizes(sizes); - SmallVector strides; - op.getStrides(strides); // all-ones at the moment - - // For each element in the tuple, generate the proper strided slice. - TupleType tupleType = op.getResultTupleType(); - int64_t tupleSize = tupleType.size(); - SmallVector tupleValues(tupleSize); - auto sliceStrides = computeStrides(shape, sizes); - for (int64_t i = 0; i < tupleSize; ++i) { - auto vectorOffsets = delinearize(sliceStrides, i); - auto elementOffsets = - computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets); - auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets); - // Insert in tuple. - tupleValues[i] = rewriter.create( - loc, op.vector(), elementOffsets, sliceSizes, strides); - } - - rewriter.replaceOpWithNewOp(op, tupleType, tupleValues); - return success(); - } -}; - -/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp. -/// One: -/// %x = vector.insert_slices %0 -/// is replaced by: -/// %r0 = zero-result -/// %t1 = vector.tuple_get %0, 0 -/// %r1 = vector.insert_strided_slice %r0, %t1 -/// %t2 = vector.tuple_get %0, 1 -/// %r2 = vector.insert_strided_slice %r1, %t2 -/// .. -/// %x = .. -class InsertSlicesOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::InsertSlicesOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - VectorType vectorType = op.getResultVectorType(); - auto shape = vectorType.getShape(); - - SmallVector sizes; - op.getSizes(sizes); - SmallVector strides; - op.getStrides(strides); // all-ones at the moment - - // Prepare result. - Value result = rewriter.create( - loc, vectorType, rewriter.getZeroAttr(vectorType)); - - // For each element in the tuple, extract the proper strided slice. - TupleType tupleType = op.getSourceTupleType(); - int64_t tupleSize = tupleType.size(); - auto sliceStrides = computeStrides(shape, sizes); - for (int64_t i = 0; i < tupleSize; ++i) { - auto vectorOffsets = delinearize(sliceStrides, i); - auto elementOffsets = - computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets); - // Extract from tuple into the result. 
- auto index = rewriter.getI64IntegerAttr(i); - auto tupleGet = rewriter.create( - loc, tupleType.getType(i), op.getOperand(), index); - result = rewriter.create( - loc, tupleGet, result, elementOffsets, strides); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - /// Progressive lowering of BroadcastOp. class BroadcastOpLowering : public OpRewritePattern { public: @@ -4194,21 +3616,6 @@ patterns.getContext(), enableIndexOptimizations); } -// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp). -// TODO: Add this as DRR pattern. -void mlir::vector::populateVectorToVectorTransformationPatterns( - RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - -void mlir::vector::populateSplitVectorTransferPatterns( - RewritePatternSet &patterns, - std::function ignoreFilter) { - patterns.add(patterns.getContext(), - ignoreFilter); -} - void mlir::vector::populatePropagateVectorDistributionPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } -void mlir::vector::populateVectorSlicesLoweringPatterns( - RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - void mlir::vector::populateVectorContractLoweringPatterns( RewritePatternSet &patterns, VectorTransformsOptions parameters) { // clang-format off @@ -4283,3 +3684,10 @@ patterns.add(patterns.getContext()); } + +void mlir::vector::populateVectorUnrollPatterns( + RewritePatternSet &patterns, const UnrollVectorOptions &options) { + patterns.add( + patterns.getContext(), options); +} diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -48,24 +48,6 @@ std::multiplies()); } -/// Given a shape with sizes greater than 0 along all dimensions, -/// return the distance, in number of elements, between a slice in a dimension -/// and the next slice in the same dimension. -/// e.g. 
shape[3, 4, 5] -> linearization_basis[20, 5, 1] -SmallVector mlir::computeStrides(ArrayRef shape) { - if (shape.empty()) - return {}; - SmallVector tmp; - tmp.reserve(shape.size()); - int64_t running = 1; - for (auto size : llvm::reverse(shape)) { - assert(size > 0 && "size must be nonnegative"); - tmp.push_back(running); - running *= size; - } - return SmallVector(tmp.rbegin(), tmp.rend()); -} - SmallVector mlir::computeStrides(ArrayRef shape, ArrayRef sizes) { int64_t rank = shape.size(); @@ -109,16 +91,6 @@ return result; } -SmallVector -mlir::computeSliceSizes(ArrayRef shape, ArrayRef sizes, - ArrayRef elementOffsets) { - int64_t rank = shape.size(); - SmallVector sliceSizes(rank); - for (unsigned r = 0; r < rank; ++r) - sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]); - return sliceSizes; -} - Optional> mlir::shapeRatio(ArrayRef superShape, ArrayRef subShape) { if (superShape.size() < subShape.size()) { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1035,26 +1035,6 @@ // ----- -func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> { - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<3x3xf32> into tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> - %1 = vector.tuple_get %0, 3 : tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> - return %1 : vector<1x1xf32> -} -// CHECK-LABEL: @extract_strides( -// CHECK-SAME: %[[ARG:.*]]: vector<3x3xf32>) -// CHECK: %[[VAL_1:.*]] = constant 0.000000e+00 : f32 -// CHECK: %[[VAL_2:.*]] = splat %[[VAL_1]] : vector<1x1xf32> -// CHECK: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : vector<3x3xf32> to !llvm.array<3 x vector<3xf32>> -// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vector<3xf32>> -// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2] : vector<3xf32>, vector<3xf32> -// CHECK: %[[VAL_6:.*]] = llvm.mlir.cast %[[VAL_2]] : vector<1x1xf32> to !llvm.array<1 x vector<1xf32>> -// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[VAL_6]][0] : !llvm.array<1 x vector<1xf32>> -// CHECK: %[[VAL_8:.*]] = llvm.mlir.cast %[[T4]] : !llvm.array<1 x vector<1xf32>> to vector<1x1xf32> -// CHECK: return %[[VAL_8]] : vector<1x1xf32> - -// ----- - func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) { // CHECK-LABEL: @vector_fma // CHECK-SAME: %[[A:.*]]: vector<8xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -864,86 +864,6 @@ %0 = vector.constant_mask [0, 2] : vector<4x3xi1> } - -// ----- - -func @extract_slices_non_unit_strides(%arg0 : vector<4x2xf32>) { - // expected-error@+1 {{requires unit strides}} - %0 = vector.extract_slices %arg0, [2, 2], [1, 3] - : vector<4x2xf32> into tuple, vector<2x2xf32>> -} - -// ----- - -func @extract_slices_tuple_element_wrong_rank(%arg0 : vector<4x2xf32>) { - // expected-error@+1 {{requires vector tuple elements of rank 2}} - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<4x2xf32> into tuple, vector<2x2x3xf32>> -} - -// ----- - -func @extract_slices_sizes_strides_wrong_rank(%arg0 : vector<4x2xf32>) { - // expected-error@+1 {{requires sizes and strides of rank}} - %0 = vector.extract_slices %arg0, [2, 2], [1, 1, 1] - : vector<4x2xf32> into tuple, vector<2x2xf32>> -} - -// 
----- - -func @extract_slices_invalid_tuple_element_type(%arg0 : vector<4x2xf32>) { - // expected-error@+1 {{invalid tuple element type}} - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<4x2xf32> into tuple, vector<4x2xf32>> -} - -// ----- - -func @tuple_of_non_vectors(%arg0 : vector<4x2xf32>) { - %c0 = constant 0 : index - // expected-error@+1 {{must be vector of any type values}} - %0 = vector.tuple %arg0, %c0 : vector<4x2xf32>, index -} - -// ----- - -func @tuple_get_of_non_vectors(%arg0 : tuple, index>) { - // expected-error@+1 {{vector of any type values}} - %0 = vector.tuple_get %arg0, 0 : tuple, index> -} - -// ----- - -func @insert_slices_non_unit_strides(%arg0 : tuple, vector<2x2xf32>>) { - // expected-error@+1 {{requires unit strides}} - %0 = vector.insert_slices %arg0, [2, 2], [1, 3] - : tuple, vector<2x2xf32>> into vector<4x2xf32> -} - -// ----- - -func @insert_slices_tuple_element_wrong_rank(%arg0 : tuple, vector<2x2x3xf32>>) { - // expected-error@+1 {{requires vector tuple elements of rank 2}} - %0 = vector.insert_slices %arg0, [2, 2], [1, 1] - : tuple, vector<2x2x3xf32>> into vector<4x2xf32> -} - -// ----- - -func @insert_slices_sizes_strides_wrong_rank(%arg0 : tuple, vector<2x2xf32>>) { - // expected-error@+1 {{requires sizes and strides of rank}} - %0 = vector.insert_slices %arg0, [2, 2], [1, 1, 1] - : tuple, vector<2x2xf32>> into vector<4x2xf32> -} - -// ----- - -func @insert_slices_invalid_tuple_element_type(%arg0 : tuple, vector<4x2xf32>>) { - // expected-error@+1 {{invalid tuple element type}} - %0 = vector.insert_slices %arg0, [2, 2], [1, 1] - : tuple, vector<4x2xf32>> into vector<4x2xf32> -} - // ----- func @print_no_result(%arg0 : f32) -> i32 { @@ -1020,15 +940,6 @@ // ----- -func @shape_cast_wrong_element_type_tuple(%arg0 : tuple, - vector<3x4x2xf32>>) { - // expected-error@+1 {{op source/result vectors must have same element type}} - %0 = vector.shape_cast %arg0 : tuple, vector<3x4x2xf32>> to - tuple, vector<12x2xi32>> -} - -// ----- - func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) { // expected-error@+1 {{op source/result number of elements must match}} %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32> @@ -1036,15 +947,6 @@ // ----- -func @shape_cast_wrong_num_elements_tuple(%arg0 : tuple, - vector<3x4x2xf32>>) { - // expected-error@+1 {{op source/result number of elements must match}} - %0 = vector.shape_cast %arg0 : tuple, vector<3x4x2xf32>> to - tuple, vector<13x2xf32>> -} - -// ----- - func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) { // expected-error@+1 {{invalid shape cast}} %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32> @@ -1052,15 +954,6 @@ // ----- -func @shape_cast_invalid_rank_reduction_tuple(%arg0 - : tuple, vector<3x4x2xf32>>) { - // expected-error@+1 {{invalid shape cast}} - %0 = vector.shape_cast %arg0: tuple, vector<3x4x2xf32>> to - tuple, vector<6x4xf32>> -} - -// ----- - func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) { // expected-error@+1 {{invalid shape cast}} %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32> @@ -1068,33 +961,6 @@ // ----- -func @shape_cast_invalid_rank_expansion_tuple(%arg0 : tuple, - vector<12x2xf32>>) { - // expected-error@+1 {{invalid shape cast}} - %0 = vector.shape_cast %arg0 : tuple, vector<12x2xf32>> to - tuple, vector<4x3x2xf32>> -} - -// ----- - -func @shape_cast_source_result_different_types( - %arg1 : tuple, vector<12x2xf32>>) { - // expected-error@+1 {{source/result must be of 
same type}} - %1 = vector.shape_cast %arg1 : tuple, vector<12x2xf32>> to - vector<5x2x4xf32> -} - -// ----- - -func @shape_cast_different_tuple_sizes( - %arg1 : tuple, vector<3x4x2xf32>>) { - // expected-error@+1 {{op source/result tuples must be the same size}} - %1 = vector.shape_cast %arg1 : tuple, vector<3x4x2xf32>> to - tuple> -} - -// ----- - func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) { // expected-error@+1 {{must be vector of any type values}} %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32 diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -334,27 +334,6 @@ return } -// CHECK-LABEL: @extract_slices -func @extract_slices(%arg0 : vector<4x2xf32>) - -> (tuple, vector<2x2xf32>>) { - // CHECK: vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple, vector<2x2xf32>> - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<4x2xf32> into tuple, vector<2x2xf32>> - %1 = vector.tuple_get %0, 0 : tuple, vector<2x2xf32>> - %2 = vector.tuple_get %0, 1 : tuple, vector<2x2xf32>> - %3 = vector.tuple %1, %2 : vector<2x2xf32>, vector<2x2xf32> - return %3 : tuple, vector<2x2xf32>> -} - -// CHECK-LABEL: @insert_slices -func @insert_slices(%arg0 : tuple, vector<2x2xf32>>) - -> (vector<4x2xf32>) { - // CHECK: vector.insert_slices %{{.*}}, [2, 2], [1, 1] : tuple, vector<2x2xf32>> into vector<4x2xf32> - %0 = vector.insert_slices %arg0, [2, 2], [1, 1] - : tuple, vector<2x2xf32>> into vector<4x2xf32> - return %0 : vector<4x2xf32> -} - // CHECK-LABEL: @vector_print func @vector_print(%arg0: vector<8x4xf32>) { // CHECK: vector.print %{{.*}} : vector<8x4xf32> @@ -381,28 +360,23 @@ // CHECK-LABEL: @shape_cast func @shape_cast(%arg0 : vector<5x1x3x2xf32>, - %arg1 : tuple, vector<3x4x2xf32>>, - %arg2 : vector<8x1xf32>, - %arg3 : vector<16x1x1xf32>) - -> (vector<15x2xf32>, tuple, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>) { + %arg1 : vector<8x1xf32>, + %arg2 : vector<16x1x1xf32>) + -> (vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>) { // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32> %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32> - // CHECK-NEXT: vector.shape_cast %{{.*}} : tuple, vector<3x4x2xf32>> to tuple, vector<12x2xf32>> - %1 = vector.shape_cast %arg1 : tuple, vector<3x4x2xf32>> to - tuple, vector<12x2xf32>> - // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x1xf32> to vector<8xf32> - %2 = vector.shape_cast %arg2 : vector<8x1xf32> to vector<8xf32> + %1 = vector.shape_cast %arg1 : vector<8x1xf32> to vector<8xf32> // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16xf32> - %3 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16xf32> + %2 = vector.shape_cast %arg2 : vector<16x1x1xf32> to vector<16xf32> // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16x1xf32> - %4 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16x1xf32> + %3 = vector.shape_cast %arg2 : vector<16x1x1xf32> to vector<16x1xf32> - return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32> + return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32> } // CHECK-LABEL: @bitcast diff --git a/mlir/test/Dialect/Vector/vector-slices-transforms.mlir b/mlir/test/Dialect/Vector/vector-slices-transforms.mlir deleted file mode 100644 --- 
a/mlir/test/Dialect/Vector/vector-slices-transforms.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: mlir-opt %s -test-vector-slices-conversion | FileCheck %s - -// CHECK-LABEL: func @extract_slices(%arg0: vector<3x3xf32>) -// CHECK: %[[SS:.*]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} -// CHECK: return %[[SS]] - -func @extract_slices(%arg0: vector<3x3xf32>) -> vector<2x2xf32> { - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<3x3xf32> into tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> - %1 = vector.tuple_get %0, 0 : tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> - return %1 : vector<2x2xf32> -} - -// CHECK-LABEL: func @insert_slices(%arg0: vector<2x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<1x2xf32>, %arg3: vector<1x1xf32>) -// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32> -// CHECK: %[[I0:.*]] = vector.insert_strided_slice %arg0, %[[C0]] {offsets = [0, 0], strides = [1, 1]} -// CHECK: %[[I1:.*]] = vector.insert_strided_slice %arg1, %[[I0]] {offsets = [0, 2], strides = [1, 1]} -// CHECK: %[[I2:.*]] = vector.insert_strided_slice %arg2, %[[I1]] {offsets = [2, 0], strides = [1, 1]} -// CHECK: %[[I3:.*]] = vector.insert_strided_slice %arg3, %[[I2]] {offsets = [2, 2], strides = [1, 1]} -// CHECK: return %[[I3]] - -func @insert_slices(%arg0: vector<2x2xf32>, - %arg1: vector<2x1xf32>, - %arg2: vector<1x2xf32>, - %arg3: vector<1x1xf32>) -> vector<3x3xf32> { - %0 = vector.tuple %arg0, %arg1, %arg2, %arg3 - : vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32> - %1 = vector.insert_slices %0, [2, 2], [1, 1] - : tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32> - return %1 : vector<3x3xf32> -} - -// CHECK-LABEL: func @extract_insert_slices(%arg0: vector<3x3xf32>) -// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<3x3xf32> -// CHECK: %[[X0:.*]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} -// CHECK: %[[X1:.*]] = vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} -// CHECK: %[[X2:.*]] = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [1, 2], strides = [1, 1]} -// CHECK: %[[X3:.*]] = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [1, 1], strides = [1, 1]} -// CHECK: %[[X4:.*]] = vector.insert_strided_slice %[[X0]], %[[C]] {offsets = [0, 0], strides = [1, 1]} -// CHECK: %[[X5:.*]] = vector.insert_strided_slice %[[X1]], %[[X4]] {offsets = [0, 2], strides = [1, 1]} -// CHECK: %[[X6:.*]] = vector.insert_strided_slice %[[X2]], %[[X5]] {offsets = [2, 0], strides = [1, 1]} -// CHECK: %[[X7:.*]] = vector.insert_strided_slice %[[X3]], %[[X6]] {offsets = [2, 2], strides = [1, 1]} -// CHECK:return %[[X7]] - -func @extract_insert_slices(%arg0: vector<3x3xf32>) -> vector<3x3xf32> { - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<3x3xf32> into tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> - %1 = vector.insert_slices %0, [2, 2], [1, 1] - : tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32> - return %1 : vector<3x3xf32> -} - -// CHECK-LABEL: func @extract_slices_tuple_leaks(%arg0: vector<4xf32>) -// CHECK: %[[X0:.*]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} -// CHECK: %[[X1:.*]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} -// CHECK: %[[X2:.*]] = vector.tuple %[[X0]], %[[X1]] -// CHECK: return %[[X2]] - -func 
@extract_slices_tuple_leaks(%arg0: vector<4xf32>) -> tuple, vector<2xf32>> { - %0 = vector.extract_slices %arg0, [2], [1] : vector<4xf32> into tuple, vector<2xf32>> - return %0 : tuple, vector<2xf32>> -} - diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -4,12 +4,14 @@ // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> -// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32> func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> { %c0 = constant 0 : index @@ -21,15 +23,14 @@ // CHECK-LABEL: func @transfer_write_unroll // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> -// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: vector.transfer_write %[[T1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> -// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: vector.transfer_write %[[T2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> -// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: vector.transfer_write %[[T3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to 
vector<2x2xf32> +// CHECK-NEXT: vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: vector.transfer_write %[[S1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: vector.transfer_write %[[S2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: return func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) { @@ -63,12 +64,14 @@ // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> -// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32> func @transfer_read_unroll_tensor(%arg0 : tensor<4x4xf32>) -> vector<4x4xf32> { %c0 = constant 0 : index @@ -80,15 +83,14 @@ // CHECK-LABEL: func @transfer_write_unroll_tensor // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> -// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[T1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : 
vector<2x2xf32>, tensor<4x4xf32> -// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[T2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> -// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[T3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[S1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[S2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[S3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> // CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32> func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>, @@ -128,14 +130,18 @@ // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 4], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, 
vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32> -// CHECK-NEXT: return %[[VEC]] : vector<4x6xf32> +// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [2, 4], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> +// CHECK-NEXT: return %[[VEC5]] : vector<4x6xf32> #map0 = affine_map<(d0, d1) -> (d1, d0)> func @transfer_read_unroll_permutation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> { %c0 = constant 0 : index @@ -150,14 +156,18 @@ // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32> -// CHECK-NEXT: return %[[VEC]] : vector<6x4xf32> +// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32> #map0 = affine_map<(d0, d1) -> (0, d1)> func @transfer_read_unroll_broadcast(%arg0 : memref<6x4xf32>) -> vector<6x4xf32> { %c0 = constant 0 : index @@ -173,14 +183,18 @@ // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : 
memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 4], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32> -// CHECK-NEXT: return %[[VEC]] : vector<4x6xf32> +// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [2, 4], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> +// CHECK-NEXT: return %[[VEC5]] : vector<4x6xf32> #map0 = affine_map<(d0, d1) -> (0, d0)> func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> { %c0 = constant 0 : index @@ -196,14 +210,18 @@ // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref, vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref, vector<2x2xf32> +// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref, 
vector<2x2xf32> -// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32> -// CHECK-NEXT: return %[[VEC]] : vector<6x4xf32> +// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32> #map0 = affine_map<(d0, d1, d2) -> (d2, d0)> func @transfer_read_unroll_different_rank(%arg0 : memref) -> vector<6x4xf32> { %c0 = constant 0 : index diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -3,17 +3,15 @@ // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)> // CHECK-LABEL: func @add4x2 -// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple, vector<2x2xf32>> -// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple, vector<2x2xf32>> -// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple, vector<2x2xf32>> -// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple, vector<2x2xf32>> -// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32> -// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>> -// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>> -// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]] : vector<2x2xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple, vector<2x2xf32>> into vector<4x2xf32> -// CHECK-NEXT: return %[[R2:.*]] : vector<4x2xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[A1:.*]] = addf %[[S1]], %[[S2]] : vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[A1]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[A2:.*]] = addf %[[S3]], %[[S4]] : vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[A2]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK-NEXT: return %[[VEC1:.*]] : vector<4x2xf32> func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> { %1 = addf %0, %0: vector<4x2xf32> @@ -21,41 +19,41 @@ } // CHECK-LABEL: func @add4x4 -// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, 
vector<2x2xf32>> -// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32> +// CHECK-NEXT: %[[A1:.*]] = addf %[[S1]], %[[S2]] : vector<2x2xf32> -// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32> +// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES1]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[A3:.*]] = addf %[[TG5]], %[[TG6]] : vector<2x2xf32> +// CHECK-NEXT: %[[A2:.*]] = addf %[[S3]], %[[S4]] : vector<2x2xf32> -// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES1]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES2]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[A4:.*]] = addf %[[TG7]], %[[TG8]] : vector<2x2xf32> +// CHECK-NEXT: %[[S5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[A3:.*]] = addf %[[S5]], %[[S6]] : vector<2x2xf32> -// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[A4:.*]] = addf %[[S7]], %[[S8]] : vector<2x2xf32> -// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES3]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> -// CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[A1]] : vector<2x2xf32> +// CHECK-NEXT: %[[S9:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// CHECK-NEXT: %[[A5:.*]] = addf %[[S9]], %[[A1]] : vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = 
@@ -76,76 +74,97 @@
 // CHECK-LABEL: func @contraction4x4_ijk
-
-// CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 6] : vector<4x6xi1>
-// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [6, 4] : vector<6x4xi1>
-
 // Reducing output vector [0, 0]
-// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x6xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<6x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[LMASK]], [2, 2], [1, 1] : vector<4x6xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %[[RMASK]], [2, 2], [1, 1] : vector<6x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-
-// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// Reducing output vector [0, 2]
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
-// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S6]], %[[S7]], %[[R1S00]], %[[S8]], %[[S9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S10:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S12:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S13:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S10]], %[[S11]], %[[R2S00]], %[[S12]], %[[S13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// Reducing output vector [0, 2]
+// CHECK-NEXT: %[[S14:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S17:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S18:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S14]], %[[S15]], %[[S16]], %[[S17]], %[[S18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[S19:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S21:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S20:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S22:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S19]], %[[S20]], %[[R1S02]], %[[S21]], %[[S22]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[S23:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S25:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S24:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S26:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S23]], %[[S24]], %[[R2S02]], %[[S25]], %[[S26]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 0]
-// CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S27:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S30:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S28:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S31:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S29:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S27]], %[[S28]], %[[S29]], %[[S30]], %[[S31]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[S32:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S34:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S33:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S35:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S32]], %[[S33]], %[[R1S20]], %[[S34]], %[[S35]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[S36:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S38:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S37:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S39:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S36]], %[[S37]], %[[R2S20]], %[[S38]], %[[S39]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// Reducing output vector [2, 2]
-// CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S40:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S43:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S41:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S44:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S42:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S40]], %[[S41]], %[[S42]], %[[S43]], %[[S44]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S45:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S47:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S46:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S48:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S45]], %[[S46]], %[[R1S22]], %[[S47]], %[[S48]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// Reducing output vector [2, 2]
+// CHECK-NEXT: %[[S49:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S51:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S50:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S52:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S49]], %[[S50]], %[[R2S22]], %[[S51]], %[[S52]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R3S02]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R3S20]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[R3S22]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
-// CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>
+// CHECK-NEXT: return %[[VEC4]] : vector<4x4xf32>

 func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
                          %arg2 : vector<4x4xf32>, %arg3 : index)
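The offset arithmetic behind the checks above: output tile [i, j] is reduced over the k dimension, so step k of its chain reads the lhs tile at offsets [i, k] of the vector<4x6xf32> and the rhs tile at offsets [k, j] of the vector<6x4xf32>, feeding the previous partial result in as the accumulator. A hedged sketch of one such step, with illustrative names and the indexing maps spelled out (the test itself relies on #map0-#map2 defined elsewhere in the file):

```mlir
#lhs = affine_map<(i, j, k) -> (i, k)>
#rhs = affine_map<(i, j, k) -> (k, j)>
#acc = affine_map<(i, j, k) -> (i, j)>

// Step k = 2 of the chain for output tile [0, 0].
func @ijk_step(%a: vector<4x6xf32>, %b: vector<6x4xf32>,
               %acc: vector<2x2xf32>) -> vector<2x2xf32> {
  %l = vector.extract_strided_slice %a {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]}
      : vector<4x6xf32> to vector<2x2xf32>
  %r = vector.extract_strided_slice %b {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
      : vector<6x4xf32> to vector<2x2xf32>
  %0 = vector.contract {indexing_maps = [#lhs, #rhs, #acc],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
      %l, %r, %acc : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  return %0 : vector<2x2xf32>
}
```

Note also how the full-size constant masks are rewritten per tile: every 2x2 tile of a fully set 4x6 or 6x4 mask is itself fully set, so each slice of the old %[[LMASK]]/%[[RMASK]] becomes a fresh `vector.constant_mask [2, 2]`.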
@@ -170,47 +189,47 @@
 // CHECK-LABEL: func @contraction4x4_ikj
-
-// CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 2] : vector<4x2xi1>
-// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [2, 4] : vector<2x4xi1>
-
 // Reducing output vector [0, 0]
-// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
-
-// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [0, 2]
-// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S10:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S6]], %[[S7]], %[[S8]], %[[S9]], %[[S10]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 0]
-// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S14:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S12:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S15:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S11]], %[[S12]], %[[S13]], %[[S14]], %[[S15]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // Reducing output vector [2, 2]
-// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S19:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S17:.*]] = vector.extract_strided_slice %arg1 {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[S20:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[S18:.*]] = vector.extract_strided_slice %arg2 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S16]], %[[S17]], %[[S18]], %[[S19]], %[[S20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
-// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
-// CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>
+// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32>

 func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
                          %arg2 : vector<4x4xf32>, %arg3 : index)
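In the ikj variant the reduction dimension k has extent 2, which is exactly the native tile size, so each output tile is produced by a single `vector.contract` instead of a chain of three. A sketch of the indexing maps such a tile-level contraction uses, with hypothetical map names (the test's #map0/#map2/#map3 are defined elsewhere in the file):

```mlir
#lhs = affine_map<(i, k, j) -> (i, k)>
#rhs = affine_map<(i, k, j) -> (k, j)>
#acc = affine_map<(i, k, j) -> (i, j)>

// One 2x2 output tile: i and j are parallel, k is the reduced dimension.
func @ikj_tile(%a: vector<2x2xf32>, %b: vector<2x2xf32>,
               %c: vector<2x2xf32>) -> vector<2x2xf32> {
  %0 = vector.contract {indexing_maps = [#lhs, #rhs, #acc],
                        iterator_types = ["parallel", "reduction", "parallel"],
                        kind = #vector.kind<add>}
      %a, %b, %c : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  return %0 : vector<2x2xf32>
}
```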
@@ -303,115 +322,6 @@
   return
 }

-// CHECK-LABEL: func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
-// CHECK: return %arg1
-
-func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
-  %0 = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
-  %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
-  return %1 : vector<8xf32>
-}
-
-// CHECK-LABEL: func @tuple_get_producer_consumer
-// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
-// CHECK: return %[[A7]] : vector<2x4xf32>
-
-func @tuple_get_producer_consumer(
-  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
-  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
-  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
-  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
-  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
-    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
-      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
-  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
-  %1 = vector.insert_slices %0, [2, 4], [1, 1]
-    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
-            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-    into vector<4x16xf32>
-  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
-  %2 = vector.extract_slices %1, [4, 8], [1, 1]
-    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
-  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
-                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
-  // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
-  %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
-  // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
-  %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
-  // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
-  %6 = vector.extract_slices %5, [2, 4], [1, 1]
-    : vector<4x8xf32> into
-      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
-  %7 = vector.tuple_get %6, 3
-    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %7
-  return %7 : vector<2x4xf32>
-}
-
-// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
-// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
-// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
-// CHECK: return %[[A7]] : vector<2x4xf32>
-
-func @tuple_get_producer_consumer_swizzle(
-  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
-  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
-  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
-  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
-  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
-    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
-      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
-  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
-  %1 = vector.insert_slices %0, [2, 4], [1, 1]
-    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
-            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-    into vector<4x16xf32>
-  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
-  %2 = vector.extract_slices %1, [4, 8], [1, 1]
-    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
-  %3= vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
-                             tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
-  // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
-
-  // Extract tuple elements.
-  %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
-  %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
-  // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]
-
-  // Swizzle tuple elements.
-  %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
-  // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
-  %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
-                              tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 = %7 at tupleIndex = 0, offsets = [2, 4]
-  %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
-  %9 = vector.extract_slices %8, [2, 4], [1, 1]
-    : vector<4x8xf32> into
-      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
-  %10 = vector.tuple_get %9, 3
-    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %10
-  return %10 : vector<2x4xf32>
-}
-
 // CHECK-LABEL: func @cancelling_shape_cast_ops
 // CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
 // CHECK: return %[[A0]] : vector<2x4xf32>

@@ -421,99 +331,6 @@
   return %1 : vector<2x4xf32>
 }
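The kept @cancelling_shape_cast_ops test exercises the vector-only folding that survives this patch. Its body is elided in the hunk above; a minimal, illustrative reconstruction of the pattern it checks:

```mlir
// %0 reshapes %arg0 and %1 undoes it; canonicalization folds the pair
// away, so the function returns %arg0 directly, as the CHECK expects.
func @cancelling(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
  return %1 : vector<2x4xf32>
}
```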
-// CHECK-LABEL: func @vector_transfers_vector_element_type
-// CHECK-DAG: %[[C1:.*]] = constant 1 : index
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
-// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
-// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
-// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
-
-func @vector_transfers_vector_element_type() {
-  %c0 = constant 0 : index
-  %cf0 = constant 0.000000e+00 : f32
-  %vf0 = splat %cf0 : vector<2x4xf32>
-
-  %0 = memref.alloc() : memref<6x2x1xvector<2x4xf32>>
-
-  %1 = vector.transfer_read %0[%c0, %c0, %c0], %vf0
-    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
-    : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>
-
-  %2 = vector.extract_slices %1, [1, 1, 2, 4], [1, 1, 1, 1]
-    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
-  %3 = vector.tuple_get %2, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
-  %4 = vector.tuple_get %2, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
-  %5 = vector.tuple %3, %4 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
-  %6 = vector.insert_slices %5, [1, 1, 2, 4], [1, 1, 1, 1]
-    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>
-
-  vector.transfer_write %6, %0[%c0, %c0, %c0]
-    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
-    : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
-
-  return
-}
-
-// Test that ShapeCastOp on tuple of vectors, decomposes to multiple
-// ShapeCastOps on vectors.
-// CHECK-LABEL: func @shape_cast_decomposition
-// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
-// CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
-// CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>
-
-func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
-                               %arg1 : vector<3x4x2xf32>)
-  -> (vector<20x2xf32>, vector<12x2xf32>) {
-  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
-  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
-                              tuple<vector<20x2xf32>, vector<12x2xf32>>
-  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
-  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
-  return %2, %3 : vector<20x2xf32>, vector<12x2xf32>
-}
-
-// Test that cancelling ShapeCastOps are canonicalized away.
-// EX:
-//
-// The following MLIR with cancelling ShapeCastOps:
-//
-//  %0 = source : vector<5x4x2xf32>
-//  %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
-//  %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
-//  %3 = user %2 : vector<5x4x2xf32>
-//
-// Should canonicalize to the following:
-//
-//  %0 = source : vector<5x4x2xf32>
-//  %1 = user %0 : vector<5x4x2xf32>
-//
-
-// ShapeCastOps on vectors.
-// CHECK-LABEL: func @shape_cast_fold
-// CHECK: return %{{.*}}, %{{.*}} : vector<5x4x2xf32>, vector<3x4x2xf32>
-
-func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
-  -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
-  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
-
-  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
-                              tuple<vector<20x2xf32>, vector<12x2xf32>>
-
-  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
-  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
-
-  %4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
-  %5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
-                              tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
-
-  %6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
-  %7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
-
-  return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
-}
-
 // CHECK-LABEL: func @elementwise_unroll
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
 // CHECK-DAG: %[[C2:.*]] = constant 2 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir
@@ -182,9 +182,9 @@
   %9 = vector.insert %a, %8[0] : vector<2xf32> into vector<3x2xf32>
   %10 = vector.insert %b, %9[1] : vector<2xf32> into vector<3x2xf32>
   %C = vector.insert %c, %10[2] : vector<2xf32> into vector<3x2xf32>
-  %11 = vector.tuple %A, %B : vector<2x2xf32>, vector<2x2xf32>
-  %D = vector.insert_slices %11, [2, 2], [1, 1]
-    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<2x4xf32>
+  %cst = constant dense<0.000000e+00> : vector<2x4xf32>
+  %11 = vector.insert_strided_slice %A, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
+  %D = vector.insert_strided_slice %B, %11 {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>

   vector.print %A : vector<2x2xf32>
   vector.print %B : vector<2x2xf32>
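Design note on the rewrite above: `vector.insert_strided_slice` needs a destination vector to insert into, whereas the removed tuple form produced the result directly. The updated integration tests therefore seed the destination with a zero constant; because the two 2x2 tiles fully cover the 2x4 result, every lane of the seed is overwritten and the constant is dead after lowering. A standalone sketch of the idiom (names are illustrative):

```mlir
// Concatenate two 2x2 tiles side by side into a 2x4 vector.
func @concat_tiles(%a: vector<2x2xf32>, %b: vector<2x2xf32>) -> vector<2x4xf32> {
  %seed = constant dense<0.000000e+00> : vector<2x4xf32>
  %0 = vector.insert_strided_slice %a, %seed {offsets = [0, 0], strides = [1, 1]}
      : vector<2x2xf32> into vector<2x4xf32>
  %1 = vector.insert_strided_slice %b, %0 {offsets = [0, 2], strides = [1, 1]}
      : vector<2x2xf32> into vector<2x4xf32>
  return %1 : vector<2x4xf32>
}
```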
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-extract-slices.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-extract-slices.mlir
deleted file mode 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-extract-slices.mlir
+++ /dev/null
@@ -1,79 +0,0 @@
-// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-func @entry() {
-  %f0 = constant 0.0: f32
-  %f1 = constant 1.0: f32
-  %f2 = constant 2.0: f32
-  %f3 = constant 3.0: f32
-  %f4 = constant 4.0: f32
-  %f5 = constant 5.0: f32
-  %f6 = constant 6.0: f32
-  %f7 = constant 7.0: f32
-  %f8 = constant 8.0: f32
-  %f9 = constant 9.0: f32
-  %f10 = constant 10.0: f32
-  %f11 = constant 11.0: f32
-  %f12 = constant 12.0: f32
-  %f13 = constant 13.0: f32
-  %f14 = constant 14.0: f32
-  %f15 = constant 15.0: f32
-
-  %a0 = vector.broadcast %f0 : f32 to vector<4x4xf32>
-  %a1 = vector.insert %f0, %a0[0, 0] : f32 into vector<4x4xf32>
-  %a2 = vector.insert %f1, %a1[0, 1] : f32 into vector<4x4xf32>
-  %a3 = vector.insert %f2, %a2[0, 2] : f32 into vector<4x4xf32>
-  %a4 = vector.insert %f3, %a3[0, 3] : f32 into vector<4x4xf32>
-  %a5 = vector.insert %f4, %a4[1, 0] : f32 into vector<4x4xf32>
-  %a6 = vector.insert %f5, %a5[1, 1] : f32 into vector<4x4xf32>
-  %a7 = vector.insert %f6, %a6[1, 2] : f32 into vector<4x4xf32>
-  %a8 = vector.insert %f7, %a7[1, 3] : f32 into vector<4x4xf32>
-  %a9 = vector.insert %f8, %a8[2, 0] : f32 into vector<4x4xf32>
-  %a10 = vector.insert %f9, %a9[2, 1] : f32 into vector<4x4xf32>
-  %a11 = vector.insert %f10, %a10[2, 2] : f32 into vector<4x4xf32>
-  %a12 = vector.insert %f11, %a11[2, 3] : f32 into vector<4x4xf32>
-  %a13 = vector.insert %f12, %a12[3, 0] : f32 into vector<4x4xf32>
-  %a14 = vector.insert %f13, %a13[3, 1] : f32 into vector<4x4xf32>
-  %a15 = vector.insert %f14, %a14[3, 2] : f32 into vector<4x4xf32>
-  %a16 = vector.insert %f15, %a15[3, 3] : f32 into vector<4x4xf32>
-
-  vector.print %a16 : vector<4x4xf32>
-  //
-  // test matrix:
-  //
-  // CHECK: ( ( 0, 1, 2, 3 ), ( 4, 5, 6, 7 ), ( 8, 9, 10, 11 ), ( 12, 13, 14, 15 ) )
-
-  // Tile 4x4 with 3x3 as follows:
-  //
-  //   +--------+--+
-  //   +0 1 2| 3|
-  //   |4 5 6| 7|
-  //   |8 9 10|11|
-  //   +--------+--+
-  //   |12 13 14|15|
-  //   +--------+--+
-  //
-  %es = vector.extract_slices %a16, [3, 3], [1, 1] :
-    vector<4x4xf32> into tuple<vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>>
-
-  %0 = vector.tuple_get %es, 0 : tuple<vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>>
-  %1 = vector.tuple_get %es, 1 : tuple<vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>>
-  %2 = vector.tuple_get %es, 2 : tuple<vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>>
-  %3 = vector.tuple_get %es, 3 : tuple<vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>>
-
-  vector.print %0 : vector<3x3xf32>
-  vector.print %1 : vector<3x1xf32>
-  vector.print %2 : vector<1x3xf32>
-  vector.print %3 : vector<1x1xf32>
-  //
-  // extract slices:
-  //
-  // CHECK: ( ( 0, 1, 2 ), ( 4, 5, 6 ), ( 8, 9, 10 ) )
-  // CHECK: ( ( 3 ), ( 7 ), ( 11 ) )
-  // CHECK: ( ( 12, 13, 14 ) )
-  // CHECK: ( ( 15 ) )
-
-  return
-}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-insert-slices.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-insert-slices.mlir
deleted file mode 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-insert-slices.mlir
+++ /dev/null
@@ -1,72 +0,0 @@
-// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-func @entry() {
-  %f0 = constant 0.0: f32
-  %f1 = constant 1.0: f32
-  %f2 = constant 2.0: f32
-  %f3 = constant 3.0: f32
-  %f4 = constant 4.0: f32
-  %f5 = constant 5.0: f32
-  %f6 = constant 6.0: f32
-  %f7 = constant 7.0: f32
-  %f8 = constant 8.0: f32
-  %f9 = constant 9.0: f32
-  %f10 = constant 10.0: f32
-  %f11 = constant 11.0: f32
-  %f12 = constant 12.0: f32
-  %f13 = constant 13.0: f32
-  %f14 = constant 14.0: f32
-  %f15 = constant 15.0: f32
-
-  %a0 = vector.broadcast %f0 : f32 to vector<3x3xf32>
-  %a1 = vector.insert %f0, %a0[0, 0] : f32 into vector<3x3xf32>
-  %a2 = vector.insert %f1, %a1[0, 1] : f32 into vector<3x3xf32>
-  %a3 = vector.insert %f2, %a2[0, 2] : f32 into vector<3x3xf32>
-  %a4 = vector.insert %f4, %a3[1, 0] : f32 into vector<3x3xf32>
-  %a5 = vector.insert %f5, %a4[1, 1] : f32 into vector<3x3xf32>
-  %a6 = vector.insert %f6, %a5[1, 2] : f32 into vector<3x3xf32>
-  %a7 = vector.insert %f8, %a6[2, 0] : f32 into vector<3x3xf32>
-  %a8 = vector.insert %f9, %a7[2, 1] : f32 into vector<3x3xf32>
-  %a9 = vector.insert %f10, %a8[2, 2] : f32 into vector<3x3xf32>
-
-  %b0 = vector.broadcast %f0 : f32 to vector<3x1xf32>
-  %b1 = vector.insert %f3, %b0[0, 0] : f32 into vector<3x1xf32>
-  %b2 = vector.insert %f7, %b1[1, 0] : f32 into vector<3x1xf32>
-  %b3 = vector.insert %f11, %b2[2, 0] : f32 into vector<3x1xf32>
-
-  %c0 = vector.broadcast %f0 : f32 to vector<1x3xf32>
-  %c1 = vector.insert %f12, %c0[0, 0] : f32 into vector<1x3xf32>
-  %c2 = vector.insert %f13, %c1[0, 1] : f32 into vector<1x3xf32>
-  %c3 = vector.insert %f14, %c2[0, 2] : f32 into vector<1x3xf32>
-
-  %d0 = vector.broadcast %f0 : f32 to vector<1x1xf32>
-  %d1 = vector.insert %f15, %d0[0, 0] : f32 into vector<1x1xf32>
-
-  vector.print %a9 : vector<3x3xf32>
-  vector.print %b3 : vector<3x1xf32>
-  vector.print %c3 : vector<1x3xf32>
-  vector.print %d1 : vector<1x1xf32>
-  //
-  // input slices:
-  //
-  // CHECK: ( ( 0, 1, 2 ), ( 4, 5, 6 ), ( 8, 9, 10 ) )
-  // CHECK: ( ( 3 ), ( 7 ), ( 11 ) )
-  // CHECK: ( ( 12, 13, 14 ) )
-  // CHECK: ( ( 15 ) )
-
-  %vt = vector.tuple %a9, %b3, %c3, %d1 :
-    vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>
-  %is = vector.insert_slices %vt, [3, 3], [1, 1] :
-    tuple<vector<3x3xf32>, vector<3x1xf32>, vector<1x3xf32>, vector<1x1xf32>> into vector<4x4xf32>
-
-  vector.print %is : vector<4x4xf32>
-  //
-  // insert slices:
-  //
-  // CHECK: ( ( 0, 1, 2, 3 ), ( 4, 5, 6, 7 ), ( 8, 9, 10, 11 ), ( 12, 13, 14, 15 ) )
-
-  return
-}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transpose.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transpose.mlir
@@ -33,9 +33,9 @@
   %9 = vector.insert %a, %8[0] : vector<2xf32> into vector<3x2xf32>
   %10 = vector.insert %b, %9[1] : vector<2xf32> into vector<3x2xf32>
   %C = vector.insert %c, %10[2] : vector<2xf32> into vector<3x2xf32>
-  %11 = vector.tuple %A, %B : vector<2x2xf32>, vector<2x2xf32>
-  %D = vector.insert_slices %11, [2, 2], [1, 1]
-    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<2x4xf32>
+  %cst = constant dense<0.000000e+00> : vector<2x4xf32>
+  %11 = vector.insert_strided_slice %A, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
+  %D = vector.insert_strided_slice %B, %11 {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>

   vector.print %A : vector<2x2xf32>
   vector.print %B : vector<2x2xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -45,16 +45,14 @@
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
    if (unroll) {
-      patterns.add<UnrollVectorPattern>(
-          ctx,
+      populateVectorUnrollPatterns(
+          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
-    populateVectorToVectorTransformationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
-    populateSplitVectorTransferPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
@@ -65,27 +63,35 @@
      return SmallVector<int64_t, 4>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
+    // For transfer ops, just propagate the shape coming from
+    // InsertStridedSlices/ExtractStridedSlices.
+    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
+      VectorType dstVec;
+      for (Operation *users : readOp->getUsers()) {
+        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
+        if (!extract)
+          return llvm::None;
+        auto vecType = extract.getResult().getType().cast<VectorType>();
+        if (dstVec && dstVec != vecType)
+          return llvm::None;
+        dstVec = vecType;
+      }
+      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
+                                     dstVec.getShape().end());
+    }
+    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+      auto insert = writeOp.vector().getDefiningOp<vector::InsertStridedSliceOp>();
+      if (!insert)
+        return llvm::None;
+      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
+      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
+    }
    return llvm::None;
  }

  static LogicalResult filter(Operation *op) {
-    return success(isa(op));
-  }
-};
-
-struct TestVectorSlicesConversion
-    : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
-  StringRef getArgument() const final {
-    return "test-vector-slices-conversion";
-  }
-  StringRef getDescription() const final {
-    return "Test conversion patterns that lower slices ops in the vector "
-           "dialect";
-  }
-  void runOnFunction() override {
-    RewritePatternSet patterns(&getContext());
-    populateVectorSlicesLoweringPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+    return success(isa(op));
  }
};
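The getShape additions above make transfer ops inherit their unroll shape from the strided-slice ops around them: a transfer_read takes the (unique) result type of its extract_strided_slice users, and a transfer_write takes the source type of the insert_strided_slice that feeds it. Roughly, on IR like the following (a hypothetical sketch, not a test from this patch), the read would be assigned native shape [2, 2] to match its consumer:

```mlir
func @read_tile(%mem: memref<4x4xf32>) -> vector<2x2xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  // Full 4x4 read whose only consumer is a single 2x2 slice.
  %v = vector.transfer_read %mem[%c0, %c0], %f0 : memref<4x4xf32>, vector<4x4xf32>
  %t = vector.extract_strided_slice %v {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
      : vector<4x4xf32> to vector<2x2xf32>
  return %t : vector<2x2xf32>
}
```

After unrolling, the read is split into 2x2 transfers and the extract folds into the one tile that is actually used, so no 4x4 value needs to be materialized.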
@@ -178,12 +184,12 @@
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
-    patterns.add<UnrollVectorPattern>(
-        ctx, UnrollVectorOptions()
-                 .setNativeShape(ArrayRef<int64_t>{2, 2})
-                 .setFilterConstraint([](Operation *op) {
-                   return success(isa<AddFOp>(op));
-                 }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{2, 2})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<AddFOp>(op));
+                      }));

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
@@ -199,22 +205,21 @@
        }
        return nativeShape;
      };
-      patterns.add<UnrollVectorPattern>(
-          ctx, UnrollVectorOptions()
-                   .setNativeShapeFn(nativeShapeFn)
-                   .setFilterConstraint([](Operation *op) {
-                     return success(isa<vector::ContractionOp>(op));
-                   }));
+      populateVectorUnrollPatterns(patterns,
+                                   UnrollVectorOptions()
+                                       .setNativeShapeFn(nativeShapeFn)
+                                       .setFilterConstraint([](Operation *op) {
+                                         return success(isa<vector::ContractionOp>(op));
+                                       }));
    } else {
-      patterns.add<UnrollVectorPattern>(
-          ctx, UnrollVectorOptions()
-                   .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
-                   .setFilterConstraint([](Operation *op) {
-                     return success(isa<vector::ContractionOp>(op));
-                   }));
+      populateVectorUnrollPatterns(
+          patterns, UnrollVectorOptions()
+                        .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
+                        .setFilterConstraint([](Operation *op) {
+                          return success(isa<vector::ContractionOp>(op));
+                        }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
-    populateVectorToVectorTransformationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
@@ -358,8 +363,8 @@
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
-    patterns.add<UnrollVectorPattern>(
-        ctx,
+    populateVectorUnrollPatterns(
+        patterns,
        UnrollVectorOptions()
            .setNativeShape(ArrayRef<int64_t>{2, 2})
            .setFilterConstraint([](Operation *op) {
@@ -367,7 +372,6 @@
                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
            }));
    populateVectorToVectorCanonicalizationPatterns(patterns);
-    populateVectorToVectorTransformationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
@@ -463,8 +467,6 @@
 void registerTestVectorConversions() {
  PassRegistration<TestVectorToVectorConversion>();
-  PassRegistration<TestVectorSlicesConversion>();
-  PassRegistration();
  PassRegistration();