diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -85,21 +85,56 @@ LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                     ArrayRef<int64_t> targetShape);

+/// Options that control the vector unrolling.
+struct UnrollVectorOptions {
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  /// Callback function that indicates whether vector unrolling should be
+  /// attempted on the operation.
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+    filterConstraint = std::move(constraint);
+    return *this;
+  }
+  /// Misspelled alias of `setFilterConstraint`, kept for compatibility.
+  UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+    return setFilterConstraint(std::move(constraint));
+  }
+
+  using NativeShapeFnType =
+      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+  /// Function that returns the shape of the vector to unroll to for a given
+  /// operation. The unrolling is aborted if the function returns `llvm::None`.
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = std::move(fn);
+    return *this;
+  }
+
+  /// Set the native shape to use for unrolling.
+  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
+    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+      return tsShape;
+    };
+    return *this;
+  }
+};

-/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
-/// declaratively.
+/// Pattern to apply `unrollSingleResultVectorOp` declaratively. The target
+/// shape and the op filter are taken from the `UnrollVectorOptions`.
 template <typename OpTy>
 struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
-  UnrollVectorPattern(
-      ArrayRef<int64_t> targetShape, MLIRContext *context,
-      FilterConstraintType constraint = [](OpTy op) { return success(); })
-      : OpRewritePattern<OpTy>(context),
-        targetShape(targetShape.begin(), targetShape.end()),
-        filter(constraint) {}
+  UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
+      : OpRewritePattern<OpTy>(context), options(std::move(options)) {}
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(filter(op)))
+    if (options.filterConstraint && failed(options.filterConstraint(op)))
       return failure();
+    // A missing shape callback is a configuration error, so diagnose it.
+    if (!options.nativeShape) {
+      return op.emitError("vector unrolling expects the native shape or native"
+                          " shape call back function to be set");
+    }
     auto unrollableVectorOp =
         dyn_cast<VectorUnrollOpInterface>(op.getOperation());
     if (!unrollableVectorOp)
@@ -107,19 +142,23 @@
     auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
     if (!maybeUnrollShape)
       return failure();
-    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+    // A `None` shape means the callback declines to unroll this op.
+    Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+    if (!targetShape)
+      return failure();
+    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
     if (!maybeShapeRatio ||
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
     if (std::is_same<OpTy, vector::TransferWriteOp>::value) {
-      if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
         return failure();
       rewriter.eraseOp(op);
       return success();
     }
     if (op.getOperation()->getNumResults() != 1)
       return failure();
-    auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+    auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
     if (resultVector.size() != 1)
       return failure();
     rewriter.replaceOp(op, resultVector.front());
@@ -127,8 +166,7 @@
   }

 private:
-  SmallVector<int64_t, 4> targetShape;
-  FilterConstraintType filter;
+  UnrollVectorOptions options;
 };

 /// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+
+func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+                          %init : vector<8x8xf32>) -> vector<8x8xf32> {
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+                           affine_map<(i, j, k) -> (j, k)>,
+                           affine_map<(i, j, k) -> (i, j)>],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
+// CHECK-LABEL: func @vector_contract_f32
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   return
+
+func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
+                          %init : vector<8x8xf16>) -> vector<8x8xf16> {
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+                           affine_map<(i, j, k) -> (j, k)>,
+                           affine_map<(i, j, k) -> (i, j)>],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
+  return %0 : vector<8x8xf16>
+}
+// CHECK-LABEL: func @vector_contract_f16
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   return
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"

@@ -26,9 +27,10 @@
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
-    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<AddFOp>>(
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ArrayRef<int64_t>{2, 2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
@@ -113,16 +115,44 @@

 struct TestVectorUnrollingPatterns
     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+  TestVectorUnrollingPatterns() = default;
+  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
-    patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ArrayRef<int64_t>{2, 2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<AddFOp>>(
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+
+    if (unrollBasedOnType) {
+      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
+          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
+        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
+        if (auto floatType = contractOp.getLhsType()
+                                 .getElementType()
+                                 .dyn_cast<FloatType>()) {
+          if (floatType.getWidth() == 16) {
+            nativeShape[2] = 4;
+          }
+        }
+        return nativeShape;
+      };
+      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+          ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+    } else {
+      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+          ctx,
+          UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+    }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
+
+  Option<bool> unrollBasedOnType{
+      *this, "unroll-based-on-type",
+      llvm::cl::desc("Set the unroll factor based on type of the operation"),
+      llvm::cl::init(false)};
 };

 struct TestVectorDistributePatterns
@@ -165,9 +195,9 @@
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
     patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
-        ArrayRef<int64_t>{2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
-        ArrayRef<int64_t>{2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);