diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -464,6 +464,18 @@
 /// ```mlir
 /// affine_map<(d0) -> (0, 0, d0, 0)>
 /// ```
+/// Example 4:
+///
+/// ```mlir
+/// affine_map<(d0, d1, d2) -> (d0, 0)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+/// affine_map<(d0, d1) -> (d0, 0, 0)>
+/// ```
+
 AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
 
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -498,17 +498,30 @@
 bool AffineMap::isProjectedPermutation() const {
   if (getNumSymbols() > 0)
     return false;
+
   SmallVector<bool, 8> seen(getNumInputs(), false);
+  unsigned numValidResults = 0;
+  // A projected permutation can have, at most, only one instance of each input
+  // dimension in the result expressions. Zeros are allowed as long as the total
+  // number of result expressions is less than or equal to the number of input
+  // expressions.
   for (auto expr : getResults()) {
     if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
       if (seen[dim.getPosition()])
         return false;
       seen[dim.getPosition()] = true;
-      continue;
+      ++numValidResults;
+    } else if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      if (constExpr.getValue() != 0)
+        return false;
+      ++numValidResults;
+    } else {
+      // Neither dim nor zero.
+      return false;
     }
-    return false;
   }
-  return true;
+
+  return numValidResults <= getNumInputs();
 }
 
 bool AffineMap::isPermutation() const {
@@ -702,7 +715,15 @@
   // Start with all the results as 0.
   SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
   for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
-    // Reverse each dimension existing in the oringal map result.
+    // Skip zeros from input map. 'exprs' is already initialized to zero.
+    if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) {
+      assert(constExpr.getValue() == 0 &&
+             "Unexpected constant in projected permutation");
+      (void)constExpr;
+      continue;
+    }
+
+    // Reverse each dimension existing in the original map result.
     exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
   }
   return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -857,3 +857,61 @@
   return %red : tensor<4xf32>
 }
 
+// -----
+
+// CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+
+// CHECK-LABEL: func @explicit_broadcast(
+func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M5]]} : tensor<4x1xf32>, vector<4x4xf32>
+  // CHECK: subf {{.*}} : vector<4x4xf32>
+  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
+  %c0 = constant 0.0 : f32
+  %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
+  %fill = linalg.fill(%c0, %init) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0, 0)>,
+                                          affine_map<(d0, d1) -> (d0, d1)>],
+   iterator_types = ["parallel", "parallel"]}
+   ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x1xf32>)
+   outs(%fill : tensor<4x4xf32>) {
+  ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+    %40 = subf %arg7, %arg8 : f32
+    linalg.yield %40 : f32
+  } -> tensor<4x4xf32>
+  return %red : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-DAG: #[[$M7:.*]] = affine_map<(d0) -> (d0, 0)>
+
+// CHECK-LABEL: func @fused_broadcast_red_2d
+func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4xf32> {
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M7]]} : tensor<4xf32>, vector<4x4xf32>
+  // CHECK: subf {{.*}} : vector<4x4xf32>
+  // CHECK: math.exp {{.*}} : vector<4x4xf32>
+  // CHECK: addf {{.*}} : vector<4x4xf32>
+  // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
+  %c0 = constant 0.0 : f32
+  %init = linalg.init_tensor [4] : tensor<4xf32>
+  %fill = linalg.fill(%c0, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0, 0)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x1xf32>)
+   outs(%fill : tensor<4xf32>) {
+  ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+    %40 = subf %arg7, %arg8 : f32
+    %41 = math.exp %40 : f32
+    %42 = addf %41, %arg9 : f32
+    linalg.yield %42 : f32
+  } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}