diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -54,6 +54,12 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns, MLIRContext *context); +/// Collect a set of vector contraction transformation patterns +/// that express all vector.contract ops in terms of more elementary +/// extraction and reduction ops. +void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns, + MLIRContext *context); + /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -216,6 +216,29 @@ }]; } +// TODO(ajcbik): quick version with "fused" accumulator; next step +// will merge Reduction/ReductionV2 into one with +// an optional accumulator instead +def Vector_ReductionV2Op : + Vector_Op<"reductionv2", [NoSideEffect]>, + Arguments<(ins StrAttr:$kind, VectorOf<[F32, F64]>:$vector, AnyType:$acc)>, + Results<(outs AnyType:$dest)> { + let summary = "reduction operation"; + let description = [{ + As vector.reduction, but with a fused accumulator (add/mul for fp only). 
+ }]; + let verifier = ?; + let assemblyFormat = [{ + $kind `,` $vector `,` $acc attr-dict `:` + type($vector) `,` type($acc) `into` type($dest) + }]; + let extraClassDeclaration = [{ + VectorType getVectorType() { + return vector().getType().cast(); + } + }]; +} + def Vector_BroadcastOp : Vector_Op<"broadcast", [NoSideEffect, PredOpTrait<"source operand and result have same element type", diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -340,6 +340,33 @@ } }; +// TODO(ajcbik): merge Reduction and ReductionV2 +class VectorReductionV2OpConversion : public LLVMOpLowering { +public: + explicit VectorReductionV2OpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ReductionV2Op::getOperationName(), context, + typeConverter) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto reductionOp = cast(op); + auto kind = reductionOp.kind(); + Type eltType = reductionOp.dest().getType(); + Type llvmType = lowering.convertType(eltType); + if (kind == "add") { + rewriter.replaceOpWithNewOp( + op, llvmType, operands[1], operands[0]); + return matchSuccess(); + } else if (kind == "mul") { + rewriter.replaceOpWithNewOp( + op, llvmType, operands[1], operands[0]); + return matchSuccess(); + } + return matchFailure(); + } +}; + class VectorShuffleOpConversion : public LLVMOpLowering { public: explicit VectorShuffleOpConversion(MLIRContext *context, @@ -1125,11 +1152,12 @@ VectorInsertStridedSliceOpSameRankRewritePattern, VectorStridedSliceOpConversion>(ctx); patterns.insert(ctx, converter); + VectorReductionV2OpConversion, VectorShuffleOpConversion, + VectorExtractElementOpConversion, VectorExtractOpConversion, + VectorFMAOp1DConversion, 
VectorInsertElementOpConversion, + VectorInsertOpConversion, VectorOuterProductOpConversion, + VectorTypeCastOpConversion, VectorPrintOpConversion>( + ctx, converter); } namespace { @@ -1139,11 +1167,12 @@ } // namespace void LowerVectorToLLVMPass::runOnModule() { - // Perform progressive lowering of operations on "slices". - // Folding and DCE get rid of all non-leaking tuple ops. + // Perform progressive lowering of operations on "slices" and + // all contraction operations. Also applies folding and DCE. { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); + populateVectorContractLoweringPatterns(patterns, &getContext()); applyPatternsGreedily(getModule(), patterns); } diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -538,6 +538,7 @@ } namespace { + // Splits vector TransferReadOp into smaller TransferReadOps based on slicing // scheme of its unique ExtractSlicesOp user. struct SplitTransferReadOp : public OpRewritePattern { @@ -862,6 +863,72 @@ } }; +/// Progressive lowering of ContractionOp. +class ContractionOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + // TODO(ajcbik): implement masks + if (llvm::size(op.masks()) != 0) + return matchFailure(); + + auto loc = op.getLoc(); + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + Type resType = op.getResultType(); + + // Find first batch dimension in lhs/rhs, and lower when found. + std::vector> batchDimMap = op.getBatchDimMap(); + if (!batchDimMap.empty()) { + // TODO(ajcbik): implement batch + return matchFailure(); + } + + // Collect contracting dimensions. 
+ std::vector> contractingDimMap = + op.getContractingDimMap(); + DenseSet lhsContractingDimSet; + DenseSet rhsContractingDimSet; + for (auto &dimPair : contractingDimMap) { + lhsContractingDimSet.insert(dimPair.first); + rhsContractingDimSet.insert(dimPair.second); + } + + // Find free dimension in lhs/rhs, and lower first when found. + for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) { + if (lhsContractingDimSet.count(i) == 0) { + // TODO(ajcbik): implement free + return matchFailure(); + } + } + for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) { + if (rhsContractingDimSet.count(i) == 0) { + // TODO(ajcbik): implement free + return matchFailure(); + } + } + + // Only contraction dimensions remain. + if (!resType.isa() && lhsType.getRank() == 1 && + rhsType.getRank() == 1) { + // Handle reduction into scalar. + Value zero = rewriter.create(loc, resType, + rewriter.getZeroAttr(resType)); + Value splat = rewriter.create(loc, lhsType, zero); + Value fma = + rewriter.create(loc, op.lhs(), op.rhs(), splat); + StringAttr kind = rewriter.getStringAttr("add"); + rewriter.replaceOpWithNewOp(op, resType, kind, fma, + op.acc()); + return matchSuccess(); + } + // TODO(ajcbik): implement more contraction + return matchFailure(); + } +}; + } // namespace // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp). 
@@ -876,3 +943,8 @@ OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } + +void mlir::vector::populateVectorContractLoweringPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); +} diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s + +#dotp_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#dotp_trait = { + indexing_maps = #dotp_accesses, + iterator_types = ["reduction"] +} + +// CHECK-LABEL: func @extract_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> +// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32> +// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]] +// CHECK: return %[[R]] : f32 + +func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xf32>, vector<4xf32> into f32 + return %0 : f32 +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -42,16 +42,29 @@ } }; +struct TestVectorContractionConversion + : public FunctionPass { + void runOnFunction() override { + OwningRewritePatternList patterns; + populateVectorContractLoweringPatterns(patterns, &getContext()); + applyPatternsGreedily(getFunction(), patterns); + } +}; + } // end anonymous namespace namespace mlir { void registerTestVectorConversions() { - PassRegistration pass( + 
PassRegistration vectorToVectorPass( "test-vector-to-vector-conversion", "Test conversion patterns between ops in the vector dialect"); - PassRegistration slices_pass( + PassRegistration slicesPass( "test-vector-slices-conversion", "Test conversion patterns that lower slices ops in the vector dialect"); + + PassRegistration contractionPass( + "test-vector-contraction-conversion", + "Test conversion patterns that lower contract ops in the vector dialect"); } } // namespace mlir