diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -444,7 +444,7 @@ // Note: `ConcreteOp` corresponds to the derived operation typename. InterfaceMethod<"/*insert doc here*/", "unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{ - ConcreteOp op = cast<ConcreteOp>(getOperation()); + ConcreteOp op = cast<ConcreteOp>(this->getOperation()); return op.getNumInputs() + op.getNumOutputs(); }]>, ]; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -21,6 +21,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorUnrollInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" // Pull in all enum type definitions and utility function declarations. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/VectorUnrollInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" def StandardOps_Dialect : Dialect { @@ -196,7 +197,8 @@ // AbsFOp //===----------------------------------------------------------------------===// -def AbsFOp : FloatUnaryOp<"absf"> { +def AbsFOp : FloatUnaryOp<"absf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point absolute-value operation"; let description = [{ The `absf` operation computes the absolute value.
It takes one operand and @@ -222,7 +224,8 @@ // AddCFOp //===----------------------------------------------------------------------===// -def AddCFOp : ComplexFloatArithmeticOp<"addcf"> { +def AddCFOp : ComplexFloatArithmeticOp<"addcf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "complex number addition"; let description = [{ The `addcf` operation takes two complex number operands and returns their @@ -242,7 +245,8 @@ // AddFOp //===----------------------------------------------------------------------===// -def AddFOp : FloatArithmeticOp<"addf"> { +def AddFOp : FloatArithmeticOp<"addf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point addition operation"; let description = [{ Syntax: @@ -279,7 +283,8 @@ // AddIOp //===----------------------------------------------------------------------===// -def AddIOp : IntArithmeticOp<"addi", [Commutative]> { +def AddIOp : IntArithmeticOp<"addi", [Commutative, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer addition operation"; let description = [{ Syntax: @@ -403,7 +408,8 @@ // AndOp //===----------------------------------------------------------------------===// -def AndOp : IntArithmeticOp<"and", [Commutative]> { +def AndOp : IntArithmeticOp<"and", [Commutative, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer binary and"; let description = [{ Syntax: @@ -773,7 +779,8 @@ // CeilFOp //===----------------------------------------------------------------------===// -def CeilFOp : FloatUnaryOp<"ceilf"> { +def CeilFOp : FloatUnaryOp<"ceilf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "ceiling of the specified value"; let description = [{ Syntax: @@ -838,7 +845,8 @@ [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, TypesMatchWith< "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { + "lhs", "result", "getI1SameShape($_self)">, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating-point comparison operation"; let description = 
[{ The `cmpf` operation compares its two operands according to the float @@ -922,7 +930,8 @@ [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, TypesMatchWith< "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { + "lhs", "result", "getI1SameShape($_self)">, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer comparison operation"; let description = [{ The `cmpi` operation is a generic comparison for integer-like types. Its two @@ -1246,7 +1255,8 @@ // CopySignOp //===----------------------------------------------------------------------===// -def CopySignOp : FloatArithmeticOp<"copysign"> { +def CopySignOp : FloatArithmeticOp<"copysign", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "A copysign operation"; let description = [{ Syntax: @@ -1280,7 +1290,8 @@ // CosOp //===----------------------------------------------------------------------===// -def CosOp : FloatUnaryOp<"cos"> { +def CosOp : FloatUnaryOp<"cos", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "cosine of the specified value"; let description = [{ Syntax: @@ -1309,7 +1320,8 @@ }]; } -def SinOp : FloatUnaryOp<"sin"> { +def SinOp : FloatUnaryOp<"sin", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "sine of the specified value"; let description = [{ Syntax: @@ -1425,7 +1437,8 @@ // DivFOp //===----------------------------------------------------------------------===// -def DivFOp : FloatArithmeticOp<"divf"> { +def DivFOp : FloatArithmeticOp<"divf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point division operation"; } @@ -1433,7 +1446,8 @@ // ExpOp //===----------------------------------------------------------------------===// -def ExpOp : FloatUnaryOp<"exp"> { +def ExpOp : FloatUnaryOp<"exp", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "base-e exponential of the specified value"; let description = [{ Syntax: @@ -1465,7 +1479,8 @@ // ExpOp
//===----------------------------------------------------------------------===// -def Exp2Op : FloatUnaryOp<"exp2"> { +def Exp2Op : FloatUnaryOp<"exp2", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "base-2 exponential of the specified value"; } @@ -1749,15 +1764,18 @@ // LogOp //===----------------------------------------------------------------------===// -def LogOp : FloatUnaryOp<"log"> { +def LogOp : FloatUnaryOp<"log", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "base-e logarithm of the specified value"; } -def Log10Op : FloatUnaryOp<"log10"> { +def Log10Op : FloatUnaryOp<"log10", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "base-10 logarithm of the specified value"; } -def Log2Op : FloatUnaryOp<"log2"> { +def Log2Op : FloatUnaryOp<"log2", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "base-2 logarithm of the specified value"; } @@ -1839,7 +1857,8 @@ // MulFOp //===----------------------------------------------------------------------===// -def MulFOp : FloatArithmeticOp<"mulf"> { +def MulFOp : FloatArithmeticOp<"mulf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point multiplication operation"; let description = [{ Syntax: @@ -1876,7 +1895,8 @@ // MulIOp //===----------------------------------------------------------------------===// -def MulIOp : IntArithmeticOp<"muli", [Commutative]> { +def MulIOp : IntArithmeticOp<"muli", [Commutative, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer multiplication operation"; let hasFolder = 1; } @@ -1885,7 +1905,8 @@ // NegFOp //===----------------------------------------------------------------------===// -def NegFOp : FloatUnaryOp<"negf"> { +def NegFOp : FloatUnaryOp<"negf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point negation"; let description = [{ Syntax: @@ -1918,7 +1939,8 @@ // OrOp //===----------------------------------------------------------------------===// -def OrOp : IntArithmeticOp<"or", [Commutative]> { +def OrOp : IntArithmeticOp<"or", [Commutative, +
DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer binary or"; let description = [{ Syntax: @@ -2067,7 +2089,8 @@ // RemFOp //===----------------------------------------------------------------------===// -def RemFOp : FloatArithmeticOp<"remf"> { +def RemFOp : FloatArithmeticOp<"remf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point division remainder operation"; } @@ -2107,7 +2130,8 @@ // RsqrtOp //===----------------------------------------------------------------------===// -def RsqrtOp : FloatUnaryOp<"rsqrt"> { +def RsqrtOp : FloatUnaryOp<"rsqrt", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "reciprocal of sqrt (1 / sqrt of the specified value)"; let description = [{ The `rsqrt` operation computes the reciprocal of the square root. It takes @@ -2122,6 +2146,7 @@ //===----------------------------------------------------------------------===// def SelectOp : Std_Op<"select", [NoSideEffect, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>, AllTypesMatch<["true_value", "false_value", "result"]>]> { let summary = "select operation"; let description = [{ @@ -2181,7 +2206,8 @@ // ShiftLeftOp //===----------------------------------------------------------------------===// -def ShiftLeftOp : IntArithmeticOp<"shift_left"> { +def ShiftLeftOp : IntArithmeticOp<"shift_left", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer left-shift"; let description = [{ The shift_left operation shifts an integer value to the left by a variable @@ -2201,7 +2227,8 @@ // SignedDivIOp //===----------------------------------------------------------------------===// -def SignedDivIOp : IntArithmeticOp<"divi_signed"> { +def SignedDivIOp : IntArithmeticOp<"divi_signed", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "signed integer division operation"; let description = [{ Syntax: @@ -2236,7 +2263,8 @@ // SignedRemIOp //===----------------------------------------------------------------------===// -def SignedRemIOp : IntArithmeticOp<"remi_signed"> { +def SignedRemIOp :
IntArithmeticOp<"remi_signed", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "signed integer division remainder operation"; let description = [{ Syntax: @@ -2271,7 +2299,8 @@ // SignedShiftRightOp //===----------------------------------------------------------------------===// -def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> { +def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "signed integer right-shift"; let description = [{ The shift_right_signed operation shifts an integer value to the right by @@ -2408,7 +2437,8 @@ // SqrtOp //===----------------------------------------------------------------------===// -def SqrtOp : FloatUnaryOp<"sqrt"> { +def SqrtOp : FloatUnaryOp<"sqrt", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "sqrt of the specified value"; let description = [{ The `sqrt` operation computes the square root. It takes one operand and @@ -2523,7 +2553,8 @@ // SubFOp //===----------------------------------------------------------------------===// -def SubFOp : FloatArithmeticOp<"subf"> { +def SubFOp : FloatArithmeticOp<"subf", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "floating point subtraction operation"; let hasFolder = 1; } @@ -2532,7 +2563,8 @@ // SubIOp //===----------------------------------------------------------------------===// -def SubIOp : IntArithmeticOp<"subi"> { +def SubIOp : IntArithmeticOp<"subi", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer subtraction operation"; let hasFolder = 1; } @@ -2817,7 +2849,8 @@ // TanhOp //===----------------------------------------------------------------------===// -def TanhOp : FloatUnaryOp<"tanh"> { +def TanhOp : FloatUnaryOp<"tanh", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "hyperbolic tangent of the specified value"; let description = [{ Syntax: @@ -3017,7 +3050,8 @@ // UnsignedDivIOp //===----------------------------------------------------------------------===// -def UnsignedDivIOp :
IntArithmeticOp<"divi_unsigned"> { +def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "unsigned integer division operation"; let description = [{ Syntax: @@ -3052,7 +3086,8 @@ // UnsignedRemIOp //===----------------------------------------------------------------------===// -def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> { +def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "unsigned integer division remainder operation"; let description = [{ Syntax: @@ -3087,7 +3122,8 @@ // UnsignedShiftRightOp //===----------------------------------------------------------------------===// -def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> { +def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned", + [DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "unsigned integer right-shift"; let description = [{ The shift_right_unsigned operation shifts an integer value to the right by @@ -3174,7 +3210,8 @@ // XOrOp //===----------------------------------------------------------------------===// -def XOrOp : IntArithmeticOp<"xor", [Commutative]> { +def XOrOp : IntArithmeticOp<"xor", [Commutative, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> { let summary = "integer binary xor"; let description = [{ The `xor` operation takes two operands and returns one result, each of these diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -19,6 +19,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorUnrollInterface.h" namespace mlir { class MLIRContext; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++
b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -15,6 +15,7 @@ include "mlir/Dialect/Affine/IR/AffineOpsBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/VectorUnrollInterface.td" def Vector_Dialect : Dialect { let name = "vector"; @@ -39,10 +40,13 @@ // TODO(andydavis, ntv) Add an attribute to specify a different algebra // with operators other than the current set: {*, +}. def Vector_ContractionOp : - Vector_Op<"contract", [NoSideEffect, - PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, - PredOpTrait<"third operand acc and result have same element type", - TCresVTEtIsSameAsOpBase<0, 2>>]>, + Vector_Op<"contract", [ + NoSideEffect, + PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, + PredOpTrait<"third operand acc and result have same element type", + TCresVTEtIsSameAsOpBase<0, 2>>, + DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]> + ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic<VectorOf<[I1]>>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, @@ -896,7 +900,9 @@ } def Vector_TransferReadOp : - Vector_Op<"transfer_read">, + Vector_Op<"transfer_read", [ + DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]> + ]>, Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices, AffineMapAttr:$permutation_map, AnyType:$padding, OptionalAttr<BoolArrayAttr>:$masked)>, @@ -1068,7 +1074,9 @@ } def Vector_TransferWriteOp : - Vector_Op<"transfer_write">, + Vector_Op<"transfer_write", [ + DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]> + ]>, Arguments<(ins AnyVector:$vector, AnyMemRef:$memref, Variadic<Index>:$indices, AffineMapAttr:$permutation_map, diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -10,6 +10,8 @@ #define DIALECT_VECTOR_VECTORTRANSFORMS_H_ #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include
"mlir/IR/Function.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -25,43 +27,83 @@ namespace vector { -// Entry point for unrolling declarative pattern rewrites. -// `op` is unrolled to the `targetShape` as follows, for each of its operands: -// 1. the unrolled type `unrolledVectorType` and number of unrolled instances -// `numUnrolledInstances` are computed from the `targetShape`. For now it is -// assumed the unrolling factors divide the vector sizes. -// 2. a fakeFork cast op is inserted that takes the operand and returns -// `numUnrolledInstances` results of type `unrolledVectorType`. -// 3. the original op is cloned `numUnrolledInstances` times, once for each -// result of the fakeFork cast op. -// 4. a fakeJoin cast op takes all these results and merges them into a single -// aggregate vector result whose size matches the original non-unrolled op -// operand types. -// -// Example: -// -// opA(operand0, operand1) // numUnrolledInstances = 3 -// -// operand0 operand1 -// | | -// fork fork -// <----------gather all fork ops ---------> -// /|\ /|\ -// f00 f01 f02 f10 f11 f12 -// <---------- clone op 3 times ---------> -// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) -// \ | / -// <-------------------- join -------------------------> -// -// Other local patterns then kick in iteratively (including DCE) and compose -// until all the fakeFork and fakeJoin ops are removed. -// -// This will be extended in the future to support more advanced use cases than -// simple pointwise ops. +/// Entry point for unrolling declarative pattern rewrites. +/// `op` is unrolled to the `targetShape` as follows, for each of its operands: +/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances +/// `numUnrolledInstances` are computed from the `targetShape`. For now it is +/// assumed the unrolling factors divide the vector sizes. +/// 2.
a fakeFork cast op is inserted that takes the operand and returns +/// `numUnrolledInstances` results of type `unrolledVectorType`. +/// 3. the original op is cloned `numUnrolledInstances` times, once for each +/// result of the fakeFork cast op. +/// 4. a fakeJoin cast op takes all these results and merges them into a +/// single aggregate vector result whose size matches the original +/// non-unrolled op operand types. +/// +/// Example: +/// +/// opA(operand0, operand1) // numUnrolledInstances = 3 +/// +/// operand0 operand1 +/// | | +/// fork fork +/// <----------gather all fork ops ---------> +/// /|\ /|\ +/// f00 f01 f02 f10 f11 f12 +/// <---------- clone op 3 times ---------> +/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) +/// \ | / +/// <-------------------- join -------------------------> +/// +/// Other local patterns then kick in iteratively (including DCE) and compose +/// until all the fakeFork and fakeJoin ops are removed. +/// +/// This will be extended in the future to support more advanced use cases than +/// simple pointwise ops. SmallVector<Value, 1> unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape); +/// Pattern to apply `unrollSingleResultOpMatchingType` to a `targetShape` +/// declaratively.
+template <typename OpTy> +struct UnrollVectorPattern : public OpRewritePattern<OpTy> { + using FilterConstraintType = std::function<LogicalResult(OpTy op)>; + UnrollVectorPattern( + ArrayRef<int64_t> targetShape, MLIRContext *context, + FilterConstraintType constraint = [](OpTy op) { return success(); }) + : OpRewritePattern<OpTy>(context), + targetShape(targetShape.begin(), targetShape.end()), + filter(constraint) {} + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(filter(op))) + return failure(); + auto unrollableVectorOp = + dyn_cast<VectorUnrollOpInterface>(op.getOperation()); + if (!unrollableVectorOp) + return failure(); + auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); + if (!maybeUnrollShape) + return failure(); + auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape); + if (!maybeShapeRatio || + llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) + return failure(); + auto resultVector = + unrollSingleResultOpMatchingType(rewriter, op, targetShape); + if (resultVector.size() != 1) + return failure(); + rewriter.replaceOp(op, resultVector.front()); + return success(); + } + +private: + // Owned copy: callers commonly pass a temporary such as ArrayRef<int64_t>{2, 2}. + SmallVector<int64_t, 4> targetShape; + FilterConstraintType filter; +}; + } // namespace vector //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h @@ -0,0 +1,25 @@ +//===- VectorUnrollInterface.h - Vector unroll interface --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operation interface for unrollable vector operations.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H +#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +#include "mlir/Interfaces/VectorUnrollInterface.h.inc" + +} // namespace mlir + +#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td @@ -0,0 +1,45 @@ +//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for operations on vectors that can be unrolled. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE +#define MLIR_INTERFACES_VECTORUNROLLINTERFACE + +include "mlir/IR/OpBase.td" + +def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> { + let description = [{ + Encodes properties of an operation on vectors that can be unrolled. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns the shape over which this op is unrolled (by default its first + result's vector shape; ops may override, e.g. with iteration bounds). + Returns `None` if the op cannot be unrolled. + }], + "Optional<SmallVector<int64_t, 4>>", + "getShapeForUnroll", + (ins), + /*methodBody=*/[{}], + [{ + auto vt = this->getOperation()->getResult(0).getType().
+ template dyn_cast<VectorType>(); + if (!vt) + return None; + SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end()); + return res; + }] + >, + ]; +} + +#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -469,6 +469,12 @@ return res; } +Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() { + SmallVector<int64_t, 4> shape; + getIterationBounds(shape); + return shape; +} + //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// @@ -1522,6 +1528,11 @@ return OpFoldResult(); } +Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { + auto s = getVectorType().getShape(); + return SmallVector<int64_t, 4>{s.begin(), s.end()}; +} + //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// @@ -1612,6 +1623,11 @@ return foldMemRefCast(*this); } +Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() { + auto s = getVectorType().getShape(); + return SmallVector<int64_t, 4>{s.begin(), s.end()}; +} + //===----------------------------------------------------------------------===// // ShapeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/VectorUnrollInterface.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorUnrollInterface.cpp new
file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/VectorUnrollInterface.cpp @@ -0,0 +1,18 @@ +//===- VectorUnrollInterface.cpp - Unrollable vector operations in MLIR ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/VectorUnrollInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// VectorUnroll Interfaces +//===----------------------------------------------------------------------===// + +/// Include the definitions of the VectorUnroll interfaces. +#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc" diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)> diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -84,6 +84,20 @@ } }; +struct TestVectorUnrollingPatterns + : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { + void runOnFunction() override { + MLIRContext *ctx = &getContext(); + OwningRewritePatternList patterns; + patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx); + patterns.insert<UnrollVectorPattern<vector::ContractionOp>>( + ArrayRef<int64_t>{2, 2, 2}, ctx); + populateVectorToVectorCanonicalizationPatterns(patterns, ctx); +
populateVectorToVectorTransformationPatterns(patterns, ctx); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + } // end anonymous namespace namespace mlir { @@ -99,5 +113,9 @@ PassRegistration<TestVectorContractionConversion> contractionPass( "test-vector-contraction-conversion", "Test conversion patterns that lower contract ops in the vector dialect"); + + PassRegistration<TestVectorUnrollingPatterns> unrollingPass( + "test-vector-unrolling-patterns", + "Test patterns to unroll addf and contract ops in the vector dialect"); } } // namespace mlir