Index: mlir/include/mlir/Dialect/GPU/GPUOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/GPUOps.td +++ mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -1130,11 +1130,12 @@ def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">; def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">; def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">; +def GPU_ELEMENTWISE_OP_DIVF : StrEnumAttrCase<"DIVF">; def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp", "elementwise operation to apply to mma matrix", [GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL, - GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> { + GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF, GPU_ELEMENTWISE_OP_DIVF]> { let cppNamespace = "::mlir::gpu"; let storageType = "::mlir::StringAttr"; let returnType = "::mlir::gpu::MMAElementwiseOp"; Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -304,6 +304,8 @@ return builder.create(loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MULF: return builder.create(loc, operands[0].getType(), operands); + case gpu::MMAElementwiseOp::DIVF: + return builder.create(loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MAXF: return createMinMaxF(builder, loc, operands[0], operands[1], /*isMin=*/false); Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp =================================================================== --- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -50,26 +50,7 @@ if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) return false; - // Check that the size matches what is natively supported. - VectorType lhsType = contract.lhs().getType().cast(); - VectorType rhsType = contract.rhs().getType().cast(); - VectorType accType = contract.acc().getType().cast(); - - std::tuple dim(lhsType.getDimSize(0), rhsType.getDimSize(1), - lhsType.getDimSize(1)); - if (lhsType.getElementType().isInteger(8) && - rhsType.getElementType().isInteger(8) && - accType.getElementType().isInteger(32) && - (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) || - dim == std::make_tuple(16, 8, 32))) - return true; - - if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() && - (accType.getElementType().isF16() || accType.getElementType().isF32()) && - (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) || - dim == std::make_tuple(16, 8, 16))) - return true; - return false; + return true; } // Return the stide for the dimension 0 of |type| if it is a memref and has a @@ -95,8 +76,15 @@ return false; if (!getMemrefConstantHorizontalStride(readOp.getShapedType())) return false; + AffineMap map = readOp.permutation_map(); + OpBuilder b(readOp.getContext()); + AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1); + AffineExpr zero = b.getAffineConstantExpr(0); + auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, + readOp.getContext()); // TODO: Support transpose once it is added to GPU dialect ops. - if (!readOp.permutation_map().isMinorIdentity()) + // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). + if (!map.isMinorIdentity() && map != broadcastInnerDim) return false; return true; } @@ -142,6 +130,8 @@ return gpu::MMAElementwiseOp::MAXF; if (isa(op)) return gpu::MMAElementwiseOp::MINF; + if (isa(op)) + return gpu::MMAElementwiseOp::DIVF; return llvm::None; } @@ -166,6 +156,43 @@ return elementwiseSupportsMMAMatrixType(op); } +/// Return an unsorted slice without including all the elements within scf.for +/// region. +static SetVector getSliceContract(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(op); + unsigned currentIndex = 0; + SetVector backwardSlice; + SetVector forwardSlice; + while (currentIndex != slice.size()) { + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. + backwardSlice.clear(); + getBackwardSlice(currentOp, &backwardSlice, backwardFilter); + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentOp. + forwardSlice.clear(); + // Special case for ForOp, we don't want to include the whole region but + // only the value using the region arguments. + // TODO: We should refine this to only care about the region arguments being + // converted to matrix type. + if (auto forOp = dyn_cast(currentOp)) { + for (Value forOpResult : forOp.getResults()) + getForwardSlice(forOpResult, &forwardSlice, forwardFilter); + for (BlockArgument &arg : forOp.getRegionIterArgs()) + getForwardSlice(arg, &forwardSlice, forwardFilter); + } else { + getForwardSlice(currentOp, &forwardSlice, forwardFilter); + } + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return slice; +} + // Analyze slice of operations based on convert op to figure out if the whole // slice can be converted to MMA operations. static SetVector getOpToConvert(mlir::Operation *op) { @@ -182,7 +209,7 @@ if (opToConvert.contains(contract.getOperation())) return; SetVector dependentOps = - getSlice(contract, hasVectorDest, hasVectorSrc); + getSliceContract(contract, hasVectorDest, hasVectorSrc); // If any instruction cannot use MMA matrix type drop the whole // chaine. MMA matrix are stored in an opaque type so they cannot be used // by all operations. @@ -191,7 +218,8 @@ return; opToConvert.insert(dependentOps.begin(), dependentOps.end()); }); - return opToConvert; + // Sort the operations so that we can convert them in topological order. + return topologicalSort(opToConvert); } namespace { @@ -309,6 +337,12 @@ assert(transferReadSupportsMMAMatrixType(op)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); + AffineMap map = op.permutation_map(); + // Handle broadcast by setting the stride to 0. + if (map.getResult(0).isa()) { + assert(map.getResult(0).cast().getValue() == 0); + stride = 0; + } assert(stride); const char *fragType = inferFragType(op); gpu::MMAMatrixType type = Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir =================================================================== --- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -106,3 +106,28 @@ vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } + +// CHECK-LABEL: func @matmul_fused_broadcast +// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> +// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> +// CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[E]] {operation = "DIVF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16> +func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, + %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16> + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>} + : memref<16x16x16x16xf16>, vector<16x16xf16> + %F = arith.divf %D, %E : vector<16x16xf16> + vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> + return +}