diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -54,6 +54,12 @@
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
+/// Collect a set of vector contraction transformation patterns
+/// that express all vector.contract ops in terms of more elementary
+/// extraction and reduction ops.
+void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -183,6 +183,62 @@
   }];
 }
 
+def Vector_ReductionOp :
+  Vector_Op<"reduction", [NoSideEffect,
+     PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
+    Results<(outs AnyType:$dest)> {
+  let summary = "reduction operation";
+  let description = [{
+    Reduces a 1-D vector "horizontally" into a scalar using the given
+    operation (add/mul/min/max for int/fp and and/or/xor for int only).
+    Note that these operations are restricted to 1-D vectors to remain
+    close to the corresponding LLVM intrinsics:
+
+    http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
+
+    Examples:
+    ```
+    %1 = vector.reduction "add", %0 : vector<16xf32> into f32
+
+    %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let assemblyFormat = [{
+    $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+  }];
+}
+
+// TODO(ajcbik): quick version with "fused" accumulator; next step
+//               will merge Reduction/ReductionV2 into one with
+//               an optional accumulator instead
+def Vector_ReductionV2Op :
+  Vector_Op<"reductionv2", [NoSideEffect]>,
+    Arguments<(ins StrAttr:$kind, VectorOf<[F32, F64]>:$vector, AnyType:$acc)>,
+    Results<(outs AnyType:$dest)> {
+  let summary = "reduction operation";
+  let description = [{
+    As vector.reduction, but with a fused accumulator (add/mul for fp only).
+  }];
+  let verifier = ?;
+  let assemblyFormat = [{
+    $kind `,` $vector `,` $acc attr-dict `:`
+    type($vector) `,` type($acc) `into` type($dest)
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [NoSideEffect,
     PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -124,6 +124,7 @@
 }
 
 namespace {
+
 class VectorBroadcastOpConversion : public LLVMOpLowering {
 public:
   explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -272,6 +273,100 @@
   }
 };
 
+class VectorReductionOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorReductionOpConversion(MLIRContext *context,
+                                       LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::ReductionOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<vector::ReductionOp>(op);
+    auto kind = reductionOp.kind();
+    Type eltType = reductionOp.dest().getType();
+    Type llvmType = lowering.convertType(eltType);
+    if (eltType.isInteger(32) || eltType.isInteger(64)) {
+      // Integer reductions: add/mul/min/max/and/or/xor.
+      if (kind == "add")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
+            op, llvmType, operands[0]);
+      else if (kind == "mul")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
+            op, llvmType, operands[0]);
+      else if (kind == "min")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
+            op, llvmType, operands[0]);
+      else if (kind == "max")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
+            op, llvmType, operands[0]);
+      else if (kind == "and")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
+            op, llvmType, operands[0]);
+      else if (kind == "or")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
+            op, llvmType, operands[0]);
+      else if (kind == "xor")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
+            op, llvmType, operands[0]);
+      else
+        return matchFailure();
+      return matchSuccess();
+
+    } else if (eltType.isF32() || eltType.isF64()) {
+      // Floating-point reductions: add/mul/min/max
+      if (kind == "add") {
+        Value zero = rewriter.create<LLVM::ConstantOp>(
+            op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+            op, llvmType, zero, operands[0]);
+      } else if (kind == "mul") {
+        Value one = rewriter.create<LLVM::ConstantOp>(
+            op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+            op, llvmType, one, operands[0]);
+      } else if (kind == "min")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
+            op, llvmType, operands[0]);
+      else if (kind == "max")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
+            op, llvmType, operands[0]);
+      else
+        return matchFailure();
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+// TODO(ajcbik): merge Reduction and ReductionV2
+class VectorReductionV2OpConversion : public LLVMOpLowering {
+public:
+  explicit VectorReductionV2OpConversion(MLIRContext *context,
+                                         LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::ReductionV2Op::getOperationName(), context,
+                       typeConverter) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<vector::ReductionV2Op>(op);
+    auto kind = reductionOp.kind();
+    Type eltType = reductionOp.dest().getType();
+    Type llvmType = lowering.convertType(eltType);
+    if (kind == "add") {
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+          op, llvmType, operands[1], operands[0]);
+      return matchSuccess();
+    } else if (kind == "mul") {
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+          op, llvmType, operands[1], operands[0]);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
 class VectorShuffleOpConversion : public LLVMOpLowering {
 public:
   explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1056,7 +1151,8 @@
       VectorInsertStridedSliceOpDifferentRankRewritePattern,
       VectorInsertStridedSliceOpSameRankRewritePattern,
       VectorStridedSliceOpConversion>(ctx);
-  patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
+  patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
+                  VectorReductionV2OpConversion, VectorShuffleOpConversion,
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
       assert(lhsDimIndex >= 0);
       iterationBounds.push_back(lhsShape[lhsDimIndex]);
       continue;
     }
     // Get parallel dimension size from result shape.
     int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
+    if (resDimIndex < 0) {
+      continue;
+    }
     assert(resDimIndex >= 0);
     assert(resVectorType != nullptr);
     iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -778,6 +778,72 @@
   }
 };
 
+/// Progressive lowering of ContractionOp.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ContractionOp op,
+                                     PatternRewriter &rewriter) const override {
+    // TODO(ajcbik): implement masks
+    if (llvm::size(op.masks()) != 0)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    Type resType = op.getResultType();
+
+    // Find first batch dimension in lhs/rhs, and lower when found.
+    std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
+    if (!batchDimMap.empty()) {
+      // TODO(ajcbik): implement batch
+      return matchFailure();
+    }
+
+    // Collect contracting dimensions.
+    std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
+        op.getContractingDimMap();
+    DenseSet<int64_t> lhsContractingDimSet;
+    DenseSet<int64_t> rhsContractingDimSet;
+    for (auto &dimPair : contractingDimMap) {
+      lhsContractingDimSet.insert(dimPair.first);
+      rhsContractingDimSet.insert(dimPair.second);
+    }
+
+    // Find free dimension in lhs/rhs, and lower first when found.
+    for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
+      if (lhsContractingDimSet.count(i) == 0) {
+        // TODO(ajcbik): implement free
+        return matchFailure();
+      }
+    }
+    for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
+      if (rhsContractingDimSet.count(i) == 0) {
+        // TODO(ajcbik): implement free
+        return matchFailure();
+      }
+    }
+
+    // Only contraction dimensions remain.
+    if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
+        rhsType.getRank() == 1) {
+      // Handle reduction into scalar.
+      Value zero = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
+      Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
+      Value fma =
+          rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), splat);
+      StringAttr kind = rewriter.getStringAttr("add");
+      rewriter.replaceOpWithNewOp<vector::ReductionV2Op>(op, resType, kind, fma,
+                                                         op.acc());
+      return matchSuccess();
+    }
+    // TODO(ajcbik): implement more contraction
+    return matchFailure();
+  }
+};
+
 } // namespace
 
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -792,3 +858,8 @@
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert(context);
 }
+
+void mlir::vector::populateVectorContractLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<ContractionOpLowering>(context);
+}
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
+
+#dotp_accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> ()>
+]
+#dotp_trait = {
+  indexing_maps = #dotp_accesses,
+  iterator_types = ["reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract1
+// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00>
+// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
+// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]] : vector<4xf32>, f32 into f32
+// CHECK: return %[[R]]
+
+func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -42,12 +42,26 @@
   }
 };
 
+struct TestVectorContractionConversion
+    : public FunctionPass<TestVectorContractionConversion> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateVectorContractLoweringPatterns(patterns, &getContext());
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+
 } // end anonymous namespace
 
-static PassRegistration<TestVectorToVectorConversion>
-    pass("test-vector-to-vector-conversion",
-         "Test conversion patterns between ops in the vector dialect");
+static PassRegistration<TestVectorToVectorConversion> vector_to_vector_pass(
+    "test-vector-to-vector-conversion",
+    "Test conversion patterns between ops in the vector dialect");
 
 static PassRegistration<TestVectorSlicesConversion> slices_pass(
     "test-vector-slices-conversion",
     "Test conversion patterns that lower slices ops in the vector dialect");
+
+static PassRegistration<TestVectorContractionConversion>
+    contraction_pass("test-vector-contraction-conversion",
+                     "Test conversion patterns that lower contraction ops in "
+                     "the vector dialect");
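
For reference, the dot-product case exercised by the new vector-contract-transforms.mlir test illustrates the intended end-to-end effect of the contraction lowering patterns. The sketch below is reconstructed from the FileCheck expectations above (with `#dotp_trait` as defined in the test); the SSA value names are illustrative and not the exact names the pass produces:

```mlir
// Input: a 1-D dot product written as a vector.contract.
func @dot(%a: vector<4xf32>, %b: vector<4xf32>, %acc: f32) -> f32 {
  %0 = vector.contract #dotp_trait %a, %b, %acc : vector<4xf32>, vector<4xf32> into f32
  return %0 : f32
}

// After -test-vector-contraction-conversion: the contraction is expressed with
// more elementary ops, i.e. a fused multiply-add into a zero vector followed by
// a horizontal reduction that fuses the scalar accumulator.
func @dot(%a: vector<4xf32>, %b: vector<4xf32>, %acc: f32) -> f32 {
  %zero = constant dense<0.000000e+00> : vector<4xf32>
  %fma = vector.fma %a, %b, %zero : vector<4xf32>
  %red = vector.reductionv2 "add", %fma, %acc : vector<4xf32>, f32 into f32
  return %red : f32
}
```

The vector.reductionv2 op then maps directly onto the LLVM reduction intrinsics via the new VectorReductionV2OpConversion pattern registered in populateVectorToLLVMConversionPatterns.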