diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -180,6 +180,12 @@
   if (failed(paddingValue))
     return failure(hasDynamicShape);
 
+  // Assume that the iter types won't change in Linalg loop nests.
+  while (auto forOp = opOperand->get().getDefiningOp<scf::ForOp>()) {
+    OpResult result = opOperand->get().cast<OpResult>();
+    opOperand = &forOp.getOpOperandForResult(result);
+  }
+
   // Cannot construct a static bounding box if the operand is not defined by an
   // ExtractSliceOp.
   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -472,3 +472,83 @@
   %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32>
   return %1 : tensor<1x?x?xf32>
 }
+
+// -----
+
+// FILL: func @matmul_bias_add(
+// FILL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<25x49xf32>,
+// FILL-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<49x33xf32>,
+// FILL-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<33xf32>,
+// FILL-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<25x33xf32>,
+// FILL-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<25x33xf32>)
+func @matmul_bias_add(%arg0: tensor<25x49xf32>,
+                      %arg1: tensor<49x33xf32>,
+                      %arg2: tensor<33xf32>,
+                      %arg3: tensor<25x33xf32>,
+                      %arg4: tensor<25x33xf32>) -> tensor<25x33xf32> {
+  %c1 = arith.constant 1 : index
+  %c49 = arith.constant 49 : index
+  %c24 = arith.constant 24 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c25 = arith.constant 25 : index
+  %c33 = arith.constant 33 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = scf.for %arg5 = %c0 to %c25 step %c8 iter_args(%arg6 = %arg4) -> (tensor<25x33xf32>) {
+    %1 = affine.min affine_map<(d0) -> (-d0 + 25, 8)>(%arg5)
+    %2 = affine.min affine_map<(d0) -> (8, -d0 + 25)>(%arg5)
+    %3 = scf.for %arg7 = %c0 to %c33 step %c16 iter_args(%arg8 = %arg6) -> (tensor<25x33xf32>) {
+      %4 = affine.min affine_map<(d0) -> (-d0 + 33, 16)>(%arg7)
+      %5 = tensor.extract_slice %arg3[%arg5, %arg7] [%1, %4] [1, 1] : tensor<25x33xf32> to tensor<?x?xf32>
+
+      // FILL: %[[PAD_IN:.+]] = tensor.pad
+      // FILL: %[[FILL:.+]] = linalg.fill(%{{.*}}, %[[PAD_IN]]
+      // FILL: %[[SLICED_FILL:.+]] = tensor.extract_slice %[[FILL]]
+      // FILL: %[[R0:.+]] = scf.for {{.*}} iter_args(%[[IT:.+]] = %[[SLICED_FILL]]) -> (tensor<?x?xf32>) {
+      // FILL: %[[T0:.+]] = tensor.extract_slice %[[ARG0]]
+      // FILL: %[[T1:.+]] = tensor.extract_slice %[[ARG1]]
+      // FILL: %[[T2:.+]] = tensor.extract_slice %[[IT]]
+      // FILL: %[[PAD_T0:.+]] = tensor.pad %[[T0]]
+      // FILL: %[[PAD_T1:.+]] = tensor.pad %[[T1]]
+      // FILL: %[[PAD_T2:.+]] = tensor.pad %[[T2]]
+      // FILL: %{{.+}} = linalg.matmul ins(%[[PAD_T0]], %[[PAD_T1]]
+      // FILL-SAME: outs(%[[PAD_T2]]
+
+      %6 = linalg.fill(%cst, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+      %7 = tensor.dim %6, %c0 : tensor<?x?xf32>
+      %8 = tensor.dim %6, %c1 : tensor<?x?xf32>
+      %9 = scf.for %arg9 = %c0 to %c49 step %c24 iter_args(%arg10 = %6) -> (tensor<?x?xf32>) {
+        %15 = affine.min affine_map<(d0) -> (24, -d0 + 49)>(%arg9)
+        %16 = tensor.extract_slice %arg0[%arg5, %arg9] [%1, %15] [1, 1] : tensor<25x49xf32> to tensor<?x?xf32>
+        %17 = tensor.extract_slice %arg1[%arg9, %arg7] [%15, %4] [1, 1] : tensor<49x33xf32> to tensor<?x?xf32>
+        %18 = tensor.extract_slice %arg10[0, 0] [%7, %8] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %19 = linalg.matmul ins(%16, %17 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%18 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %20 = tensor.insert_slice %19 into %arg10[0, 0] [%7, %8] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+        scf.yield %20 : tensor<?x?xf32>
+      }
+
+      // FILL: %[[T3:.+]] = tensor.extract_slice %[[ARG2]]
+      // FILL: %[[T4:.+]] = tensor.extract_slice %{{.+}}
+      // FILL: %[[PAD_R0:.+]] = tensor.pad %[[R0]]
+      // FILL: %[[PAD_T3:.+]] = tensor.pad %[[T3]]
+      // FILL: %[[PAD_T4:.+]] = tensor.pad %[[T4]]
+      // FILL: %{{.+}} = linalg.generic
+      // FILL-SAME: ins(%[[PAD_R0]], %[[PAD_T3]]
+      // FILL-SAME: outs(%[[PAD_T4]]
+
+      %10 = affine.min affine_map<(d0) -> (16, -d0 + 33)>(%arg7)
+      %11 = tensor.extract_slice %arg2[%arg7] [%10] [1] : tensor<33xf32> to tensor<?xf32>
+      %12 = tensor.extract_slice %arg8[%arg5, %arg7] [%2, %10] [1, 1] : tensor<25x33xf32> to tensor<?x?xf32>
+      %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %11 : tensor<?x?xf32>, tensor<?xf32>) outs(%12 : tensor<?x?xf32>) {
+      ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
+        %15 = arith.addf %arg9, %arg10 : f32
+        linalg.yield %15 : f32
+      } -> tensor<?x?xf32>
+      %14 = tensor.insert_slice %13 into %arg8[%arg5, %arg7] [%2, %10] [1, 1] : tensor<?x?xf32> into tensor<25x33xf32>
+      scf.yield %14 : tensor<25x33xf32>
+    }
+    scf.yield %3 : tensor<25x33xf32>
+  }
+  return %0 : tensor<25x33xf32>
+}