diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -38,18 +38,25 @@ // TODO(andydavis, ntv) Add an attribute to specify a different algebra // with operators other than the current set: {*, +}. def Vector_ContractionOp : - Vector_Op<"contract", [NoSideEffect]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc, + Vector_Op<"contract", [NoSideEffect, + PredOpTrait<"first operand lhs and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"second operand rhs and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>, + PredOpTrait<"third operand acc and result have same element type", + TCresVTEtIsSameAsOpBase<0, 2>>]>, + Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic<VectorOf<[I1]>>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, - Results<(outs AnyVector)> { + Results<(outs AnyType)> { let summary = "vector contraction operation"; let description = [{ Computes the sum of products of vector elements along contracting dimension pairs from 2 vectors of rank M and N respectively, adds this intermediate result to the accumulator argument of rank K, and returns a vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims + - num_batch_dims (see dimension type descriptions below)). + num_batch_dims (see dimension type descriptions below)). For K = 0 (no + free or batch dimensions), the accumulator and output are a scalar. 
Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp) specify the dynamic dimension sizes of valid data within the lhs/rhs vector @@ -59,7 +66,7 @@ the list represents an iterator with one of the following types: *) "reduction": reduction dimensions are present in the lhs and rhs - arguments but not in the output (or optional accumulator + arguments but not in the output (and accumulator argument). These are the dimensions along which the vector contraction op computes the sum of products, and contracting dimension pair dimension sizes must match @@ -81,30 +88,44 @@ Examples: - // 2D vector contraction with one contracting dimension (matmul). + // Simple dot product (K = 0). + #contraction_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> + ] + #contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = ["reduction"] + } + %3 = vector.contract #contraction_trait %0, %1, %2 + : vector<10xf32>, vector<10xf32> into f32 + + // 2D vector contraction with one contracting dimension (matmul, K = 2). #contraction_accesses = [ - (i, j, k) -> (i, k), - (i, j, k) -> (k, j), - (i, j, k) -> (i, j) + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = [parallel, parallel, reduction] + iterator_types = ["parallel", "parallel", "reduction"] } %3 = vector.contract #contraction_trait %0, %1, %2 : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32> // 4D to 3D vector contraction with two contracting dimensions and - // one batch dimension. + // one batch dimension (K = 3). 
#contraction_accesses = [ - (b0, f0, f1, c0, c1) -> (c0, b0, c1, f0), - (b0, f0, f1, c0, c1) -> (b0, c1, c0, f1), - (b0, f0, f1, c0, c1) -> (b0, f0, f1) + affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>, + affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>, + affine_map<(b0, f0, f1, c0, c1) -> (b0, f0, f1)> ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = [parallel, parallel, parallel reduction, reduction] + iterator_types = ["parallel", "parallel", "parallel", + "reduction", "reduction"] } %4 = vector.contract #contraction_trait %0, %1, %2 @@ -128,8 +149,8 @@ VectorType getRhsType() { return rhs().getType().cast<VectorType>(); } - VectorType getAccType() { - return acc().getType().cast<VectorType>(); + Type getAccType() { + return acc().getType(); } VectorType getLHSVectorMaskType() { if (llvm::size(masks()) != 2) return VectorType(); @@ -139,8 +160,8 @@ if (llvm::size(masks()) != 2) return VectorType(); return getOperand(4).getType().cast<VectorType>(); } - VectorType getResultType() { - return getResult().getType().cast<VectorType>(); + Type getResultType() { + return getResult().getType(); } ArrayRef<StringRef> getTraitAttrNames(); SmallVector<AffineMap, 4> getIndexingMaps(); diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -80,7 +80,7 @@ OpAsmParser::OperandType accInfo; SmallVector<OpAsmParser::OperandType, 2> masksInfo; SmallVector<Type, 2> types; - Type resultVectorType; + Type resultType; auto loc = parser.getCurrentLocation(); DictionaryAttr dictAttr; // TODO(andydavis, ntv) Unify linalg op attribute parsing. 
@@ -91,11 +91,11 @@ parser.parseTrailingOperandList(masksInfo) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(types) || - parser.parseKeywordType("into", resultVectorType) || + parser.parseKeywordType("into", resultType) || parser.resolveOperand(lhsInfo, types[0], result.operands) || parser.resolveOperand(rhsInfo, types[1], result.operands) || - parser.resolveOperand(accInfo, resultVectorType, result.operands) || - parser.addTypeToList(resultVectorType, result.types)) + parser.resolveOperand(accInfo, resultType, result.operands) || + parser.addTypeToList(resultType, result.types)) return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); @@ -148,8 +148,7 @@ } static bool verifyOutputShape( - VectorType lhsType, VectorType rhsType, VectorType accType, - VectorType resType, + VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap, const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) { DenseSet<int64_t> lhsContractingDimSet; @@ -177,18 +176,39 @@ expectedResultDims.push_back(rhsType.getDimSize(i)); } - // Verify dimension from 'resType' against 'expectedResultDims'. - if (resType.getShape().size() != expectedResultDims.size() || - accType.getShape().size() != expectedResultDims.size()) - return false; - for (int64_t i = 0, e = resType.getRank(); i < e; ++i) { - if (resType.getDimSize(i) != expectedResultDims[i] || - accType.getDimSize(i) != expectedResultDims[i]) + // Verify 'expectedResultDims'. + if (expectedResultDims.size() == 0) { + // No batch or free dimension implies a scalar result. + if (resType.isa<VectorType>() || accType.isa<VectorType>()) return false; + + } else { + // At least one batch or free dimension implies a vector result. + auto resVectorType = resType.dyn_cast<VectorType>(); + auto accVectorType = accType.dyn_cast<VectorType>(); + if (!resVectorType || !accVectorType) + return false; + + // Verify dimension from 'resType' against 'expectedResultDims'. 
+ if (resVectorType.getShape().size() != expectedResultDims.size() || + accVectorType.getShape().size() != expectedResultDims.size()) + return false; + for (int64_t i = 0, e = resVectorType.getRank(); i < e; ++i) { + if (resVectorType.getDimSize(i) != expectedResultDims[i] || + accVectorType.getDimSize(i) != expectedResultDims[i]) + return false; + } } return true; } +static unsigned getTypeRank(Type type) { + auto vectorType = type.dyn_cast<VectorType>(); + if (vectorType) + return vectorType.getShape().size(); + return 0; +} + static LogicalResult verify(ContractionOp op) { auto lhsType = op.getLhsType(); auto rhsType = op.getRhsType(); @@ -209,11 +229,12 @@ if (map.getNumSymbols() != 0) return op.emitOpError("expected indexing map ") << index << " to have no symbols"; + unsigned rank = getTypeRank(op.getOperand(index).getType()); + if (map.getNumDims() == 0 && map.getNumResults() == 0 && rank == 0) + continue; // (i) -> () is parsed into empty map; accept for rank=0 if (map.getNumDims() != numIterators) return op.emitOpError("expected indexing map ") << index << " to have " << numIterators << " number of inputs"; - auto operandType = op.getOperand(index).getType().cast<VectorType>(); - unsigned rank = operandType.getShape().size(); if (map.getNumResults() != rank) return op.emitOpError("expected indexing map ") << index << " to have " << rank << " number of outputs"; @@ -291,7 +312,7 @@ void ContractionOp::getIterationBounds( SmallVectorImpl<int64_t> &iterationBounds) { auto lhsShape = getLhsType().getShape(); - auto resShape = getResultType().getShape(); + auto resVectorType = getResultType().dyn_cast<VectorType>(); SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps()); SmallVector<int64_t, 4> iterationShape; for (auto it : llvm::enumerate(iterator_types())) { @@ -308,7 +329,8 @@ // Get parallel dimension size from result shape. 
int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr); assert(resDimIndex >= 0); - iterationBounds.push_back(resShape[resDimIndex]); + assert(resVectorType != nullptr); + iterationBounds.push_back(resVectorType.getShape()[resDimIndex]); } } diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -707,6 +707,25 @@ // ----- +#contraction_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> + ] +#contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = ["parallel", "parallel", "reduction"] + } +func @contraction(%arg0: vector<4x3xi32>, + %arg1: vector<3x7xf32>, + %arg2: vector<4x7xf32>) -> vector<4x7xf32> { + // expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}} + %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2 + : vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32> +} + +// ----- + func @create_mask() { %c2 = constant 2 : index %c3 = constant 3 : index diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -127,6 +127,26 @@ return %1: vector<2x2x16xf32> } +#contraction_to_scalar_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#contraction_to_scalar_trait = { + indexing_maps = #contraction_to_scalar_accesses, + iterator_types = ["reduction"] +} +// CHECK-LABEL: contraction_to_scalar +func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { + // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 + %f0 = constant 0.0: f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : 
vector<10xf32>, vector<10xf32> into f32 + %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 + : vector<10xf32>, vector<10xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + #contraction_accesses0 = [ affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>, affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,