diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -444,7 +444,7 @@ // Note: `ConcreteOp` corresponds to the derived operation typename. InterfaceMethod<"/*insert doc here*/", "unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{ - ConcreteOp op = cast(getOperation()); + ConcreteOp op = cast(this->getOperation()); return op.getNumInputs() + op.getNumOutputs(); }]>, ]; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -21,6 +21,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorUnrollInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" // Pull in all enum type definitions and utility function declarations. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/VectorUnrollInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" def StandardOps_Dialect : Dialect { @@ -82,7 +83,9 @@ } class FloatUnaryOp traits = []> : - UnaryOpSameOperandAndResultType, + UnaryOpSameOperandAndResultType])>, Arguments<(ins FloatLike:$operand)>; // Base class for standard arithmetic operations.
Requires operands and @@ -112,7 +115,9 @@ // i %0, %1 : i32 // class IntArithmeticOp traits = []> : - ArithmeticOp, + ArithmeticOp])>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>; // Base class for standard arithmetic binary operations on floats, vectors and @@ -125,7 +130,9 @@ // f %0, %1 : f32 // class FloatArithmeticOp traits = []> : - ArithmeticOp, + ArithmeticOp])>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; // Base class for standard arithmetic operations on complex numbers with a diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -1,4 +1,4 @@ -//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===// +//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -19,6 +19,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorUnrollInterface.h" namespace mlir { class MLIRContext; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -15,6 +15,7 @@ include "mlir/Dialect/Affine/IR/AffineOpsBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/VectorUnrollInterface.td" def Vector_Dialect : Dialect { let name = "vector"; @@ -39,10 +40,13 @@ // TODO(andydavis, ntv) Add an attribute to specify a different algebra // with operators other than the current set: {*, +}.
def Vector_ContractionOp : - Vector_Op<"contract", [NoSideEffect, - PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, - PredOpTrait<"third operand acc and result have same element type", - TCresVTEtIsSameAsOpBase<0, 2>>]>, + Vector_Op<"contract", [ + NoSideEffect, + PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, + PredOpTrait<"third operand acc and result have same element type", + TCresVTEtIsSameAsOpBase<0, 2>>, + DeclareOpInterfaceMethods + ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, @@ -896,7 +900,9 @@ } def Vector_TransferReadOp : - Vector_Op<"transfer_read">, + Vector_Op<"transfer_read", [ + DeclareOpInterfaceMethods + ]>, Arguments<(ins AnyMemRef:$memref, Variadic:$indices, AffineMapAttr:$permutation_map, AnyType:$padding, OptionalAttr:$masked)>, @@ -1068,7 +1074,9 @@ } def Vector_TransferWriteOp : - Vector_Op<"transfer_write">, + Vector_Op<"transfer_write", [ + DeclareOpInterfaceMethods + ]>, Arguments<(ins AnyVector:$vector, AnyMemRef:$memref, Variadic:$indices, AffineMapAttr:$permutation_map, diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td --- a/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td @@ -20,7 +20,7 @@ StrJoinInt.result # "})">; class UnrollVectorOp factors> : NativeCodeCall< - "unrollSingleResultOpMatchingType($_builder, $0.getDefiningOp(), " # + "unrollSingleResultVectorOp($_builder, $0.getDefiningOp(), " # "{" # StrJoinInt.result # "})">; #endif // VECTOR_TRANSFORM_PATTERNS diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -10,6 +10,8 @@ #define
DIALECT_VECTOR_VECTORTRANSFORMS_H_ #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/IR/Function.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -25,42 +27,83 @@ namespace vector { -// Entry point for unrolling declarative pattern rewrites. -// `op` is unrolled to the `targetShape` as follows, for each of its operands: -// 1. the unrolled type `unrolledVectorType` and number of unrolled instances -// `numUnrolledInstances` are computed from the `targetShape`. For now it is -// assumed the unrolling factors divide the vector sizes. -// 2. a fakeFork cast op is inserted that takes the operand and returns -// `numUnrolledInstances` results of type `unrolledVectorType`. -// 3. the original op is cloned `numUnrolledInstances` times, once for each -// result of the fakeFork cast op. -// 4. a fakeJoin cast op takes all these results and merges them into a single -// aggregate vector result whose size matches the original non-unrolled op -// operand types. -// -// Example: -// -// opA(operand0, operand1) // numUnrolledInstances = 3 -// -// operand0 operand1 -// | | -// fork fork -// <----------gather all fork ops ---------> -// /|\ /|\ -// f00 f01 f02 f10 f11 f12 -// <---------- clone op 3 times ---------> -// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) -// \ | / -// <-------------------- join -------------------------> -// -// Other local patterns then kick in iteratively (including DCE) and compose -// until all the fakeFork and fakeJoin ops are removed. -// -// This will be extended in the future to support more advanced use cases than -// simple pointwise ops. -SmallVector -unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op, - ArrayRef targetShape); +/// Entry point for unrolling declarative pattern rewrites. +/// `op` is unrolled to the `targetShape` as follows, for each of its operands: +/// 1.
the unrolled type `unrolledVectorType` and number of unrolled instances +/// `numUnrolledInstances` are computed from the `targetShape`. For now it is +/// assumed the unrolling factors divide the vector sizes. +/// 2. a fakeFork cast op is inserted that takes the operand and returns +/// `numUnrolledInstances` results of type `unrolledVectorType`. +/// 3. the original op is cloned `numUnrolledInstances` times, once for each +/// result of the fakeFork cast op. +/// 4. a fakeJoin cast op takes all these results and merges them into a +/// single aggregate vector result whose size matches the original +/// non-unrolled op operand types. +/// +/// Example: +/// +/// opA(operand0, operand1) // numUnrolledInstances = 3 +/// +/// operand0 operand1 +/// | | +/// fork fork +/// <----------gather all fork ops ---------> +/// /|\ /|\ +/// f00 f01 f02 f10 f11 f12 +/// <---------- clone op 3 times ---------> +/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) +/// \ | / +/// <-------------------- join -------------------------> +/// +/// Other local patterns then kick in iteratively (including DCE) and compose +/// until all the fakeFork and fakeJoin ops are removed. +/// +/// This will be extended in the future to support more advanced use cases than +/// simple pointwise ops. +SmallVector unrollSingleResultVectorOp(OpBuilder &builder, + Operation *op, + ArrayRef targetShape); + +/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape` +/// declaratively.
+template +struct UnrollVectorPattern : public OpRewritePattern { + using FilterConstraintType = std::function; + UnrollVectorPattern( + ArrayRef targetShape, MLIRContext *context, + FilterConstraintType constraint = [](OpTy op) { return success(); }) + : OpRewritePattern(context), + targetShape(targetShape.begin(), targetShape.end()), + filter(constraint) {} + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(filter(op))) + return failure(); + // Check the result count first: the default `getShapeForUnroll` reads result #0. + if (op.getOperation()->getNumResults() != 1) + return failure(); + auto unrollableVectorOp = + dyn_cast(op.getOperation()); + if (!unrollableVectorOp) + return failure(); + auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); + if (!maybeUnrollShape) + return failure(); + auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape); + if (!maybeShapeRatio || + llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) + return failure(); + auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape); + if (resultVector.size() != 1) + return failure(); + rewriter.replaceOp(op, resultVector.front()); + return success(); + } + +private: + SmallVector targetShape; + FilterConstraintType filter; +}; } // namespace vector diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -5,5 +5,6 @@ add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) +add_mlir_interface(VectorUnrollInterface) add_mlir_interface(ViewLikeInterface) diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h @@ -0,0 +1,26 @@ +//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===// +//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the operation interface for vector ops that can be +// unrolled. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H +#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +#include "mlir/Interfaces/VectorUnrollInterface.h.inc" + +} // namespace mlir + +#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td @@ -0,0 +1,45 @@ +//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for operations on vectors that can be unrolled. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE +#define MLIR_INTERFACES_VECTORUNROLLINTERFACE + +include "mlir/IR/OpBase.td" + +def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> { + let description = [{ + Encodes properties of an operation on vectors that can be unrolled. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns the shape over which this op is unrolled (by default, the shape + of its single vector result).
Returns `None` if the op cannot be + unrolled. + }], + "Optional>", + "getShapeForUnroll", + (ins), + /*methodBody=*/[{}], + [{ + auto vt = this->getOperation()->getResult(0).getType(). + template dyn_cast(); + if (!vt) + return None; + SmallVector res(vt.getShape().begin(), vt.getShape().end()); + return res; + }] + >, + ]; +} + +#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -15,6 +15,7 @@ MLIREDSC MLIRIR MLIRSideEffectInterfaces + MLIRVectorUnrollInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/CMakeLists.txt @@ -19,4 +19,5 @@ MLIRSCF MLIRLoopAnalysis MLIRSideEffectInterfaces + MLIRVectorUnrollInterface ) diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -469,6 +469,12 @@ return res; } +Optional> ContractionOp::getShapeForUnroll() { + SmallVector shape; + getIterationBounds(shape); + return shape; +} + //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// @@ -1522,6 +1528,11 @@ return OpFoldResult(); } +Optional> TransferReadOp::getShapeForUnroll() { + auto s = getVectorType().getShape(); + return SmallVector{s.begin(), s.end()}; +} + //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// @@ -1612,6 +1623,11 @@ return foldMemRefCast(*this); } +Optional> TransferWriteOp::getShapeForUnroll() { + auto s =
getVectorType().getShape(); + return SmallVector{s.begin(), s.end()}; +} + //===----------------------------------------------------------------------===// // ShapeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/VectorUnrollInterface.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -357,7 +358,7 @@ // (removable with DCE). // TODO(andydavis) Generalize this to support structured ops beyond -// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType' +// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp' static Value unrollSingleResultStructuredOp(Operation *op, ArrayRef iterationBounds, std::vector &vectors, @@ -450,11 +451,7 @@ static void getVectorContractionOpUnrollState( vector::ContractionOp contractionOp, ArrayRef targetShape, - SmallVectorImpl &iterationBounds, std::vector &vectors, unsigned &resultIndex) { - // Get contraction op iteration bounds. - contractionOp.getIterationBounds(iterationBounds); - assert(iterationBounds.size() == targetShape.size()); // Get map from iteration space index to lhs/rhs/result shape index. std::vector> iterationIndexMapList; contractionOp.getIterationIndexMap(iterationIndexMapList); @@ -476,17 +473,15 @@ vectors.push_back({contractionOp.getRHSVectorMaskType(), vectors[1].indexMap, accOperandIndex + 2, false}); } - // Unroll 'op' 'iterationBounds' to 'targetShape'. // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition // 'vectors' instead of 'resultIndex'.
resultIndex = accOperandIndex; } -static void -getVectorElementwiseOpUnrollState(Operation *op, ArrayRef targetShape, - SmallVectorImpl &iterationBounds, - std::vector &vectors, - unsigned &resultIndex) { +static void getVectorElementwiseOpUnrollState(Operation *op, + ArrayRef targetShape, + std::vector &vectors, + unsigned &resultIndex) { // Verify that operation and operands all have the same vector shape. auto resultType = op->getResult(0).getType().dyn_cast_or_null(); assert(resultType && "Expected op with vector result type"); @@ -494,8 +489,6 @@ // Verify that all operands have the same vector type as result. assert(llvm::all_of(op->getOperandTypes(), [=](Type type) { return type == resultType; })); - // Populate 'iterationBounds' with 'resultShape' for elementwise operations. - iterationBounds.assign(resultShape.begin(), resultShape.end()); // Create trivial elementwise identity index map based on 'resultShape'. DenseMap indexMap; @@ -513,28 +506,34 @@ } // Entry point for unrolling declarative pattern rewrites. -SmallVector mlir::vector::unrollSingleResultOpMatchingType( - OpBuilder &builder, Operation *op, ArrayRef targetShape) { +SmallVector +mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op, + ArrayRef targetShape) { assert(op->getNumResults() == 1 && "Expected single result operation"); - // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'. - SmallVector iterationBounds; + // The iteration bounds of 'op' are given by its unroll shape. + auto unrollableVectorOp = cast(op); + auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); + assert(maybeUnrollShape && "Trying to unroll an incorrect vector op"); + + // Populate 'vectors' and 'resultIndex' to unroll 'op'. std::vector vectors; unsigned resultIndex; if (auto contractionOp = dyn_cast(op)) { + assert(maybeUnrollShape->size() == targetShape.size() && + "Expected one target size per iteration dimension"); // Populate state for vector ContractionOp.
- getVectorContractionOpUnrollState(contractionOp, targetShape, - iterationBounds, vectors, resultIndex); + getVectorContractionOpUnrollState(contractionOp, targetShape, vectors, + resultIndex); } else { // Populate state for vector elementwise op. - getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors, - resultIndex); + getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex); } - // Unroll 'op' with 'iterationBounds' to 'targetShape'. + // Unroll 'op' over its unroll shape to 'targetShape'. return SmallVector{unrollSingleResultStructuredOp( - op, iterationBounds, vectors, resultIndex, targetShape, builder)}; + op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)}; } /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp + VectorUnrollInterface.cpp ViewLikeInterface.cpp ) @@ -32,5 +33,6 @@ add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(SideEffectInterfaces) +add_mlir_interface_library(VectorUnrollInterface) add_mlir_interface_library(ViewLikeInterface) diff --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorUnrollInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/VectorUnrollInterface.cpp @@ -0,0 +1,18 @@ +//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/VectorUnrollInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// VectorUnroll Interfaces +//===----------------------------------------------------------------------===// + +/// Include the definitions of the VectorUnroll interfaces. +#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc" diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)> diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -92,6 +92,20 @@ } }; +struct TestVectorUnrollingPatterns + : public PassWrapper { + void runOnFunction() override { + MLIRContext *ctx = &getContext(); + OwningRewritePatternList patterns; + patterns.insert>(ArrayRef{2, 2}, ctx); + patterns.insert>( + ArrayRef{2, 2, 2}, ctx); + populateVectorToVectorCanonicalizationPatterns(patterns, ctx); + populateVectorToVectorTransformationPatterns(patterns, ctx); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + } // end anonymous namespace namespace mlir { @@ -107,5 +121,9 @@ PassRegistration contractionPass( "test-vector-contraction-conversion", "Test conversion patterns that lower contract ops in the vector dialect"); + + PassRegistration
contractionUnrollingPass( + "test-vector-unrolling-patterns", + "Test conversion patterns to unroll contract and elementwise vector ops"); } } // namespace mlir