diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1304,6 +1304,15 @@
                                          "no defining op on contract operands");
     }
 
+    // Only fold when both extensions start from the same element type, so the
+    // folded contraction does not end up with mismatched lhs/rhs operands.
+    auto lhsOperandTy = getElementTypeOrSelf(lhsDefOp.getOperand().getType());
+    auto rhsOperandTy = getElementTypeOrSelf(rhsDefOp.getOperand().getType());
+    if (lhsOperandTy != rhsOperandTy) {
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "lhs/rhs extf operation must match");
+    }
+
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
         contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
@@ -1313,11 +1322,75 @@
   }
 };
 
+/// Pattern to materialize arithmetic extensions of floating point data types
+/// out of vector contraction operations. This is a specialized case for when
+/// `vector.contract` has lhs/rhs inputs of non-matching element types: every
+/// input whose element type differs from the result element type is extended
+/// to it with `arith.extf`.
+///
+/// The rewrite is skipped unless the result element type is a float type and
+/// each input is either already of that type or a strictly narrower float
+/// type, since `arith.extf` cannot express any other conversion.
+struct MaterializeContractionOpExt
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = contractOp.getLhs();
+    Value rhs = contractOp.getRhs();
+    auto lhsTy = cast<VectorType>(lhs.getType());
+    auto rhsTy = cast<VectorType>(rhs.getType());
+    Type lhsElemTy = lhsTy.getElementType();
+    Type rhsElemTy = rhsTy.getElementType();
+    if (lhsElemTy == rhsElemTy) {
+      return rewriter.notifyMatchFailure(contractOp, "lhs/rhs match correctly");
+    }
+
+    Type resultTy = contractOp.getType();
+    auto resultElemTy = dyn_cast<FloatType>(getElementTypeOrSelf(resultTy));
+    if (!resultElemTy) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "result element type is not floating point");
+    }
+
+    // `arith.extf` only converts a float type to a wider float type, so bail
+    // out on integer inputs and on inputs that are not narrower than the
+    // result instead of creating an invalid extension.
+    auto canExtendToResult = [&](Type elemTy) {
+      if (elemTy == resultElemTy)
+        return true;
+      auto floatTy = dyn_cast<FloatType>(elemTy);
+      return floatTy && floatTy.getWidth() < resultElemTy.getWidth();
+    };
+    if (!canExtendToResult(lhsElemTy) || !canExtendToResult(rhsElemTy)) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "lhs/rhs cannot be extended to the result element type");
+    }
+
+    Location loc = contractOp.getLoc();
+    if (lhsElemTy != resultElemTy) {
+      lhs = rewriter.create<arith::ExtFOp>(loc, lhsTy.clone(resultElemTy), lhs);
+    }
+    if (rhsElemTy != resultElemTy) {
+      rhs = rewriter.create<arith::ExtFOp>(loc, rhsTy.clone(resultElemTy), rhs);
+    }
+
+    // NOTE(review): the combining kind is not forwarded here, so the new op
+    // presumably gets the builder's default kind -- confirm that is intended.
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, resultTy, lhs, rhs, contractOp.getAcc(),
+        contractOp.getIndexingMapsAttr(), contractOp.getIteratorTypesAttr());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+  patterns.add<FoldArithExtIntoContractionOp, MaterializeContractionOpExt>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -15,4 +15,18 @@
   %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
   %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
   return %result : vector<64x64xf32>
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_arith_extf_into_contract
+// CHECK: %[[EXTF0:.+]] = arith.extf %arg0 : vector<2x4xf16> to vector<2x4xf32>
+// CHECK: %[[EXTF1:.+]] = arith.extf %arg1 : vector<4x15xbf16> to vector<4x15xf32>
+// CHECK: %[[CONTRACT:.+]] = vector.contract
+// CHECK-SAME: %[[EXTF0]], %[[EXTF1]], %arg2 : vector<2x4xf32>, vector<4x15xf32> into vector<2x15xf32>
+func.func @no_fold_arith_extf_into_contract(%arg0 : vector<2x4xf16>, %arg1 : vector<4x15xbf16>, %arg2 : vector<2x15xf32>) -> vector<2x15xf32> {
+  %0 = arith.extf %arg0 : vector<2x4xf16> to vector<2x4xf32>
+  %1 = arith.extf %arg1 : vector<4x15xbf16> to vector<4x15xf32>
+  %2 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %arg2 : vector<2x4xf32>, vector<4x15xf32> into vector<2x15xf32>
+  return %2 : vector<2x15xf32>
+}