diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -138,6 +138,8 @@
 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
 using folded_std_constant_float = folded::ValueBuilder<ConstantFloatOp>;
+using folded_std_constant_int = folded::ValueBuilder<ConstantIntOp>;
+using folded_std_constant = folded::ValueBuilder<ConstantOp>;
 using folded_std_dim = folded::ValueBuilder<DimOp>;
 using folded_std_muli = folded::ValueBuilder<MulIOp>;
 } // namespace intrinsics
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -121,10 +121,6 @@
   DenseMap<Value, PromotionInfo> promotionInfoMap;
   for (auto v : subViews) {
     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
-    auto viewType = subView.getType();
-    // TODO(ntv): support more cases than just float.
-    if (!viewType.getElementType().isa<FloatType>())
-      continue;
    auto promotionInfo =
        promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
    promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
@@ -136,10 +132,12 @@
    auto info = promotionInfoMap.find(v);
    if (info == promotionInfoMap.end())
      continue;
-    // TODO(ntv): value to fill with should be related to the operation.
-    // For now, just use APFloat(0.0f).
-    auto t = subView.getType().getElementType().cast<FloatType>();
-    Value fillVal = folded_std_constant_float(folder, APFloat(0.0f), t);
+    Value fillVal;
+    if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
+      fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
+    else if (auto t =
+                 subView.getType().getElementType().dyn_cast<IntegerType>())
+      fillVal = folded_std_constant_int(folder, 0, t);
    // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
    // view that is different from the partial local view and we are on the
    // boundary.
@@ -214,13 +212,14 @@
     if (!op.hasBufferSemantics())
       return;
-    // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
-    // nothing.
+    // TODO(ntv) some heuristic here to decide what to promote. Atm only float
+    // and integer buffers can be promoted.
     SetVector<Value> subViews;
     OpBuilder b(op);
     for (auto it : op.getInputsAndOutputBuffers())
       if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
-        subViews.insert(sv);
+        if (sv.getType().getElementType().isIntOrFloat())
+          subViews.insert(sv);
     if (!subViews.empty()) {
       promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
       toErase.push_back(op);
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -10,34 +10,32 @@
 // CHECK-DAG: #[[strided2DnoOffset:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
 // CHECK-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
-module {
-  func @matmul(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
-    %c4 = constant 4 : index
-    %c3 = constant 3 : index
-    %c2 = constant 2 : index
-    %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
-    %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
-    %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
-    %6 = dim %3, 0 : memref<?x?xf32>
-    %7 = dim %3, 1 : memref<?x?xf32>
-    %8 = dim %4, 1 : memref<?x?xf32>
-    loop.for %arg4 = %c0 to %6 step %c2 {
-      loop.for %arg5 = %c0 to %8 step %c3 {
-        loop.for %arg6 = %c0 to %7 step %c4 {
-          %11 = std.subview %3[%arg4, %arg6][%c2, %c4][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-          %14 = std.subview %4[%arg6, %arg5][%c4, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-          %17 = std.subview %5[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-          linalg.matmul(%11, %14, %17) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
-        }
+func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
+  %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
+  %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
+  %6 = dim %3, 0 : memref<?x?xf32>
+  %7 = dim %3, 1 : memref<?x?xf32>
+  %8 = dim %4, 1 : memref<?x?xf32>
+  loop.for %arg4 = %c0 to %6 step %c2 {
+    loop.for %arg5 = %c0 to %8 step %c3 {
+      loop.for %arg6 = %c0 to %7 step %c4 {
+        %11 = std.subview %3[%arg4, %arg6][%c2, %c4][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %14 = std.subview %4[%arg6, %arg5][%c4, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %17 = std.subview %5[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        linalg.matmul(%11, %14, %17) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
       }
     }
-    return
   }
+  return
 }
 
-// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+// CHECK-LABEL: func @matmul_f32(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -74,3 +72,133 @@
 // CHECK: dealloc %[[tmpA]] : memref<32xi8>
 // CHECK: dealloc %[[tmpB]] : memref<48xi8>
 // CHECK: dealloc %[[tmpC]] : memref<24xi8>
+
+// -----
+
+func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf64>
+  %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf64>
+  %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf64>
+  %6 = dim %3, 0 : memref<?x?xf64>
+  %7 = dim %3, 1 : memref<?x?xf64>
+  %8 = dim %4, 1 : memref<?x?xf64>
+  loop.for %arg4 = %c0 to %6 step %c2 {
+    loop.for %arg5 = %c0 to %8 step %c3 {
+      loop.for %arg6 = %c0 to %7 step %c4 {
+        %11 = std.subview %3[%arg4, %arg6][%c2, %c4][%c1, %c1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, ?]>
+        %14 = std.subview %4[%arg6, %arg5][%c4, %c3][%c1, %c1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, ?]>
+        %17 = std.subview %5[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, ?]>
+        linalg.matmul(%11, %14, %17) : memref<?x?xf64, offset: ?, strides: [?, ?]>, memref<?x?xf64, offset: ?, strides: [?, ?]>, memref<?x?xf64, offset: ?, strides: [?, ?]>
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @matmul_f64(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[vA_f64:.*]] = std.subview {{.*}} : memref<?x?xf64>
+// CHECK: %[[vB_f64:.*]] = std.subview {{.*}} : memref<?x?xf64>
+// CHECK: %[[vC_f64:.*]] = std.subview {{.*}} : memref<?x?xf64>
+///
+// CHECK: %[[tmpA_f64:.*]] = alloc() : memref<64xi8>
+// CHECK: %[[fullA_f64:.*]] = std.view %[[tmpA_f64]][][{{.*}}] : memref<64xi8> to memref<?x?xf64>
+// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
+// CHECK: %[[partialA_f64:.*]] = linalg.slice %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2D_dynamic]]>
+///
+// CHECK: %[[tmpB_f64:.*]] = alloc() : memref<96xi8>
+// CHECK: %[[fullB_f64:.*]] = std.view %[[tmpB_f64]][][{{.*}}] : memref<96xi8> to memref<?x?xf64>
+// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
+// CHECK: %[[partialB_f64:.*]] = linalg.slice %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2D_dynamic]]>
+///
+// CHECK: %[[tmpC_f64:.*]] = alloc() : memref<48xi8>
+// CHECK: %[[fullC_f64:.*]] = std.view %[[tmpC_f64]][][{{.*}}] : memref<48xi8> to memref<?x?xf64>
+// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
+// CHECK: %[[partialC_f64:.*]] = linalg.slice %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2D_dynamic]]>
+
+// CHECK: linalg.fill(%[[fullA_f64]], {{.*}}) : memref<?x?xf64>, f64
+// CHECK: linalg.fill(%[[fullB_f64]], {{.*}}) : memref<?x?xf64>, f64
+// CHECK: linalg.fill(%[[fullC_f64]], {{.*}}) : memref<?x?xf64>, f64
+// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+//
+// CHECK: linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>
+//
+// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+//
+// CHECK: dealloc %[[tmpA_f64]] : memref<64xi8>
+// CHECK: dealloc %[[tmpB_f64]] : memref<96xi8>
+// CHECK: dealloc %[[tmpC_f64]] : memref<48xi8>
+
+// -----
+
+func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xi32>
+  %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xi32>
+  %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xi32>
+  %6 = dim %3, 0 : memref<?x?xi32>
+  %7 = dim %3, 1 : memref<?x?xi32>
+  %8 = dim %4, 1 : memref<?x?xi32>
+  loop.for %arg4 = %c0 to %6 step %c2 {
+    loop.for %arg5 = %c0 to %8 step %c3 {
+      loop.for %arg6 = %c0 to %7 step %c4 {
+        %11 = std.subview %3[%arg4, %arg6][%c2, %c4][%c1, %c1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, ?]>
+        %14 = std.subview %4[%arg6, %arg5][%c4, %c3][%c1, %c1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, ?]>
+        %17 = std.subview %5[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, ?]>
+        linalg.matmul(%11, %14, %17) : memref<?x?xi32, offset: ?, strides: [?, ?]>, memref<?x?xi32, offset: ?, strides: [?, ?]>, memref<?x?xi32, offset: ?, strides: [?, ?]>
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @matmul_i32(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[vA_i32:.*]] = std.subview {{.*}} : memref<?x?xi32>
+// CHECK: %[[vB_i32:.*]] = std.subview {{.*}} : memref<?x?xi32>
+// CHECK: %[[vC_i32:.*]] = std.subview {{.*}} : memref<?x?xi32>
+///
+// CHECK: %[[tmpA_i32:.*]] = alloc() : memref<32xi8>
+// CHECK: %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][][{{.*}}] : memref<32xi8> to memref<?x?xi32>
+// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
+// CHECK: %[[partialA_i32:.*]] = linalg.slice %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2D_dynamic]]>
+///
+// CHECK: %[[tmpB_i32:.*]] = alloc() : memref<48xi8>
+// CHECK: %[[fullB_i32:.*]] = std.view %[[tmpB_i32]][][{{.*}}] : memref<48xi8> to memref<?x?xi32>
+// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
+// CHECK: %[[partialB_i32:.*]] = linalg.slice %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2D_dynamic]]>
+///
+// CHECK: %[[tmpC_i32:.*]] = alloc() : memref<24xi8>
+// CHECK: %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][][{{.*}}] : memref<24xi8> to memref<?x?xi32>
+// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
+// CHECK: %[[partialC_i32:.*]] = linalg.slice %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2D_dynamic]]>
+
+// CHECK: linalg.fill(%[[fullA_i32]], {{.*}}) : memref<?x?xi32>, i32
+// CHECK: linalg.fill(%[[fullB_i32]], {{.*}}) : memref<?x?xi32>, i32
+// CHECK: linalg.fill(%[[fullC_i32]], {{.*}}) : memref<?x?xi32>, i32
+// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+//
+// CHECK: linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
+//
+// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+//
+// CHECK: dealloc %[[tmpA_i32]] : memref<32xi8>
+// CHECK: dealloc %[[tmpB_i32]] : memref<48xi8>
+// CHECK: dealloc %[[tmpC_i32]] : memref<24xi8>
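
Note (not part of the patch): the sketch below mirrors the fill-value selection added in the Promotion.cpp hunk above, to show how the two new EDSC aliases `folded_std_constant` and `folded_std_constant_int` are intended to be used together. The helper name `zeroFillValue` is hypothetical and only illustrative.

```cpp
// Illustrative sketch only; follows the logic added to Promotion.cpp.
// `zeroFillValue` is a hypothetical helper, not part of the patch.
static Value zeroFillValue(Type elementType, OperationFolder *folder) {
  // Float element types: materialize a standard constant from a FloatAttr
  // of the matching width.
  if (auto ft = elementType.dyn_cast<FloatType>())
    return folded_std_constant(folder, FloatAttr::get(ft, 0.0));
  // Integer element types: materialize a constant-int of the matching width.
  if (auto it = elementType.dyn_cast<IntegerType>())
    return folded_std_constant_int(folder, 0, it);
  // Other element types are filtered out earlier via isIntOrFloat().
  return Value();
}
```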