diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -150,6 +150,26 @@
   return true;
 }
 
+// Return true if a transfer op followed by a broadcast can be converted to an MMA matrix load.
+static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
+    vector::TransferReadOp readOp, bool useNvGpu) {
+  bool res = true;
+  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
+      readOp.getVectorType().getRank() != 1)
+    res = false;
+  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+    res = false;
+  AffineMap map = readOp.getPermutationMap();
+  OpBuilder b(readOp.getContext());
+
+  if (res && !useNvGpu)
+    return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
+
+  llvm::errs() << "RES transferReadFollowedByBroadcastSupportsMMAMatrixType: "
+               << res << "\n";
+  return res;
+}
+
 // Return true if the transfer op can be converted to a MMA matrix store.
 static bool
 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
@@ -179,8 +199,27 @@
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2 &&
-         broadcastOp.getSource().getType().isa<FloatType>();
+  auto res = broadcastOp.getVectorType().getRank() == 2 &&
+             broadcastOp.getSource().getType().isa<FloatType>();
+  llvm::errs() << "RES broadcastSupportsMMAMatrixType: " << res << "\n";
+  return res;
+}
+
+/// Return true if this is a broadcast from a 1-D to a 2-D vector and the 1-D
+/// vector comes from a TransferReadOp.
+static bool
+broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
+                                               bool useNvGpu) {
+  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+  auto sourceVectorType =
+      broadcastOp.getSource().getType().dyn_cast<VectorType>();
+  auto res =
+      !broadcastSupportsMMAMatrixType(broadcastOp) && sourceVectorType &&
+      sourceVectorType.getRank() == 1 &&
+      transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
+  llvm::errs() << "RES broadcastFromTransferReadSupportsMMAMatrixType: " << res
+               << "\n";
+  return res;
 }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -219,9 +258,10 @@
   if (failed(contractOp))
     return false;
 
-  // Handle vector.extract_strided_slice on registers containing
-  // matrixB and matrixC operands. vector.extract_strided_slice op
-  // is not supported on registers containing matrixA operands.
+  // Handle vector.extract_strided_slice on registers
+  // containing matrixB and matrixC operands.
+  // vector.extract_strided_slice op is not supported on
+  // registers containing matrixA operands.
   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
     return (op->getResult(0).getType().cast<VectorType>() ==
             (*contractOp).getRhs().getType().cast<VectorType>());
@@ -236,7 +276,9 @@
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
+           transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
+                                                                useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
@@ -246,8 +288,10 @@
     return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
-  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
-    return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+    return broadcastSupportsMMAMatrixType(broadcast) ||
+           broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
+  }
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -264,17 +308,20 @@
   SetVector<Operation *> forwardSlice;
   while (currentIndex != slice.size()) {
     auto *currentOp = (slice)[currentIndex];
-    // Compute and insert the backwardSlice starting from currentOp.
+    // Compute and insert the backwardSlice starting from
+    // currentOp.
    backwardSlice.clear();
     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
-    // Compute and insert the forwardSlice starting from currentOp.
+    // Compute and insert the forwardSlice starting from
+    // currentOp.
     forwardSlice.clear();
-    // Special case for ForOp, we don't want to include the whole region but
-    // only the value using the region arguments.
-    // TODO: We should refine this to only care about the region arguments being
-    // converted to matrix type.
+    // Special case for ForOp, we don't want to include the
+    // whole region but only the value using the region
+    // arguments.
+    // TODO: We should refine this to only care about the
+    // region arguments being converted to matrix type.
     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
       for (Value forOpResult : forOp.getResults())
         getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
@@ -307,16 +354,20 @@
       return;
     SetVector<Operation *> dependentOps =
         getSliceContract(contract, hasVectorDest, hasVectorSrc);
-    // If any instruction cannot use MMA matrix type drop the whole
-    // chain. MMA matrix are stored in an opaque type so they cannot be used
-    // by all operations.
+    // If any instruction cannot use MMA matrix type drop the
+    // whole chain. MMA matrix are stored in an opaque type so
+    // they cannot be used by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          return !supportsMMaMatrixType(op, useNvGpu);
+          auto res = !supportsMMaMatrixType(op, useNvGpu);
+          if (res)
+            llvm::errs() << "DOES NOT SUPPORT: " << *op << "\n";
+          return res;
         }))
       return;
 
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  // Sort the operations so that we can convert them in topological order.
+  // Sort the operations so that we can convert them in
+  // topological order.
   return topologicalSort(opToConvert);
 }
@@ -443,7 +494,12 @@
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+  if (!transferReadSupportsMMAMatrixType(op,
+                                         /*useNvGpu=*/false))
+    return;
+  // Only transfers that return 2-D vectors are supported.
+  if (op.getVectorType().getRank() != 2)
+    return;
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
   AffineMap map = op.getPermutationMap();
@@ -535,10 +591,11 @@
       *warpMatrixInfo,
       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
   if (failed(params)) {
-    return op->emitError()
-           << "failed to convert vector.transfer_read to ldmatrix; this op "
-              "likely "
-              "should not be converted to a nvgpu.ldmatrix call.";
+    return op->emitError() << "failed to convert vector.transfer_read to "
+                              "ldmatrix; this op "
+                              "likely "
+                              "should not be converted to a nvgpu.ldmatrix "
+                              "call.";
   }
 
   // Adjust the load offset.
@@ -572,7 +629,8 @@
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
     op->emitError() << "Failed to deduce register fragment type during "
-                       "conversion to distributed non-ldmatrix compatible load";
+                       "conversion to distributed non-ldmatrix compatible "
+                       "load";
     return failure();
   }
 
@@ -590,8 +648,8 @@
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
-  // If we are not transposing, then we can use vectorized loads. Otherwise, we
-  // must load each element individually.
+  // If we are not transposing, then we can use vectorized
+  // loads. Otherwise, we must load each element individually.
   if (!isTransposeLoad) {
     if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }
@@ -665,9 +723,9 @@
   VectorType vecTy = op.getVectorType();
   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
 
-  // When we are transposing the B operand, ldmatrix will only work if we have
-  // at least 8 rows to read and the width to read for the transpose is 128
-  // bits.
+  // When we are transposing the B operand, ldmatrix will only
+  // work if we have at least 8 rows to read and the width to
+  // read for the transpose is 128 bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
        vecTy.getDimSize(0) * bitWidth < 128))
@@ -740,7 +798,8 @@
   if (failed(mmaSyncFragmentInfo))
     return failure();
 
-  // Find the vector.transer_read whose result vector is being sliced.
+  // Find the vector.transer_read whose result vector is being
+  // sliced.
   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
     return failure();
@@ -754,12 +813,13 @@
   if (failed(ldFragmentInfo))
     return failure();
 
-  assert(
-      (mmaSyncFragmentInfo->elementsPerRegister ==
-       ldFragmentInfo->elementsPerRegister) &&
-      "Number of elements per register should be same for load and mma.sync");
+  assert((mmaSyncFragmentInfo->elementsPerRegister ==
+          ldFragmentInfo->elementsPerRegister) &&
+         "Number of elements per register should be same for "
+         "load and mma.sync");
 
-  // Create vector.extract_strided_slice op for thread-owned fragments.
+  // Create vector.extract_strided_slice op for thread-owned
+  // fragments.
   std::array<int64_t, 2> strides = {1, 1}; // stride for extract slice is always 1.
   std::array<int64_t, 2> sliceShape = {
@@ -775,9 +835,11 @@
   populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
 
-  // Compute offset in vector registers. Note that the mma.sync vector registers
-  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
-  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
+  // Compute offset in vector registers. Note that the mma.sync
+  // vector registers are shaped as numberOfFragments x
+  // numberOfRegistersPerfFragment. The vector registers can
+  // only be sliced along numberOfFragments, i.e.,
+  // sliceOffset[0].
   std::array<int64_t, 2> sliceOffset = {0, 0};
 
   if (offsets[0] && offsets[1])
@@ -842,7 +904,10 @@
 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
 static void convertBroadcastOp(vector::BroadcastOp op,
                                llvm::DenseMap<Value, Value> &valueMapping) {
-  assert(broadcastSupportsMMAMatrixType(op));
+  // Only handle broadcasts that can be converted directly to an
+  // MMA constant matrix op.
+  if (!broadcastSupportsMMAMatrixType(op))
+    return;
   OpBuilder b(op);
   const char *fragType = inferFragType(op);
   auto vecType = op.getVectorType();
@@ -853,11 +918,39 @@
   valueMapping[op.getResult()] = matrix;
 }
 
+/// Convert a vector.broadcast of a 1-D vector.transfer_read into a SubgroupMmaLoadMatrix op.
+static void
+convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
+                                   llvm::DenseMap<Value, Value> &valueMapping) {
+  // Handle the broadcasts that cannot be converted directly to an MMA
+  // constant matrix op.
+  if (broadcastSupportsMMAMatrixType(broadcastOp))
+    return;
+  if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
+                                                      /*useNvGpu=*/false))
+    return;
+  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+  assert(readOp && readOp.getVectorType().getRank() == 1);
+  // Handle broadcast by setting the stride to 0, unconditionally.
+  int64_t stride = 0;
+  const char *fragType = inferFragType(readOp);
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      broadcastOp.getVectorType().getShape(),
+      broadcastOp.getVectorType().getElementType(), fragType);
+  OpBuilder b(readOp);
+  bool isTranspose = false;
+  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+      readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
+      b.getIndexAttr(stride), isTranspose ? b.getUnitAttr() : UnitAttr());
+  valueMapping[broadcastOp.getResult()] = load;
+}
+
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                                ValueRange newIterOperands) {
-  // Create a new loop before the existing one, with the extra operands.
+  // Create a new loop before the existing one, with the extra
+  // operands.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(loop);
   auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -912,8 +1005,8 @@
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end())
       continue;
-    // Replace the yield of old value with the for op argument to make it easier
-    // to remove the dead code.
+    // Replace the yield of old value with the for op argument
+    // to make it easier to remove the dead code.
     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
     yieldOperands.push_back(it->second);
   }
@@ -959,6 +1052,7 @@
       convertConstantOp(constantOp, valueMapping);
     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
       convertBroadcastOp(broadcastOp, valueMapping);
+      convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       convertForOp(forOp, valueMapping);
     } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
@@ -1027,6 +1121,8 @@
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
 
+    getOperation()->dump();
+
     if (useNvGpu.getValue()) {
       if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
         return signalPassFailure();
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -4,7 +4,6 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map4 = affine_map<(d0) -> (d0, 0)>
 #map5 = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: func @matmul
@@ -118,6 +117,21 @@
 // CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+// func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+//                                   %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f16
+//   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+//   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+//   %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+//     {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+//     : memref<16x16x16x16xf16>, vector<16x16xf16>
+//   %F = arith.divf %D, %E : vector<16x16xf16>
+//   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+//   return
+// }
 func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
                   %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
@@ -126,9 +140,10 @@
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
-    : memref<16x16x16x16xf16>, vector<16x16xf16>
+  %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
+    : memref<16x16x16x16xf16>, vector<16xf16>
+  %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
   %F = arith.divf %D, %E : vector<16x16xf16>
   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
@@ -141,12 +156,24 @@
 // CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+// func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f16
+//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+//   return
+// }
 func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
@@ -160,12 +187,24 @@
 // CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+// func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f16
+//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
+//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+//   return
+// }
 func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>