diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -1,36 +1,3 @@ -# Declare a function to generate ODS with mlir-linalg-ods-gen -function(add_linalg_ods_tc_gen tc_filename output_file) - set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename}) - set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.td) - set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.cpp.inc) - set_source_files_properties( - ${GEN_ODS_FILE} - PROPERTIES GENERATED TRUE) - set_source_files_properties( - ${GEN_CPP_FILE} - PROPERTIES GENERATED TRUE) - add_custom_command( - OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE} - COMMAND ${MLIR_LINALG_ODS_GEN_EXE} -gen-ods-decl ${TC_SOURCE} > ${GEN_ODS_FILE} - COMMAND ${MLIR_LINALG_ODS_GEN_EXE} -gen-impl ${TC_SOURCE} > ${GEN_CPP_FILE} - MAIN_DEPENDENCY - ${TC_SOURCE} - DEPENDS - ${MLIR_LINALG_ODS_GEN_EXE} - ${MLIR_LINALG_ODS_GEN_TARGET} - VERBATIM) - add_custom_target( - MLIR${output_file}TcIncGen - DEPENDS - ${MLIR_LINALG_ODS_GEN_EXE} - ${MLIR_LINALG_ODS_GEN_TARGET} - ${GEN_ODS_FILE} ${GEN_CPP_FILE}) - # Setup the file dependencies needed for the subsequent tablegen step. - # TODO: Once there is only one way of generating named ops remove this parent - # scope manipulation and implement the tablegen generation in the same scope. - set(LLVM_TARGET_DEPENDS ${LLVM_TARGET_DEPENDS} ${GEN_ODS_FILE} PARENT_SCOPE) -endfunction() - # Declare a function to generate ODS with mlir-linalg-ods-yaml-gen function(add_linalg_ods_yaml_gen yaml_ast_file output_file) set(YAML_AST_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${yaml_ast_file}) @@ -56,25 +23,21 @@ ${MLIR_LINALG_ODS_YAML_GEN_EXE} ${MLIR_LINALG_ODS_YAML_GEN_TARGET} ${GEN_ODS_FILE} ${GEN_CPP_FILE}) - # Setup the file dependencies needed for the subsequent tablegen step. - # TODO: Once there is only one way of generating named ops remove this parent - # scope manipulation and implement the tablegen generation in the same scope. - set(LLVM_TARGET_DEPENDS ${LLVM_TARGET_DEPENDS} ${GEN_ODS_FILE} PARENT_SCOPE) + list(APPEND LLVM_TARGET_DEPENDS ${GEN_ODS_FILE}) + set(LLVM_TARGET_DEPENDS ${LLVM_TARGET_DEPENDS} PARENT_SCOPE) endfunction() -# TODO: Delete tc generation and replace with the YAML variant once all ops are -# ported. At the same time, move the YAML and TableGen generation to the same -# scope to avoid the at a distance dependency manipulation via -# LLVM_TARGET_DEPENDS. +# NOTE: LLVM_TARGET_DEPENDS gets picked up by tablegen targets to add file +# level dependencies. This is gross but CMake requires depending on both +# targets and generated files, and it must be done when the custom target is +# declared (there is no way to add after the fact). set(LLVM_TARGET_DEPENDS "") -add_linalg_ods_tc_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps) add_linalg_ods_yaml_gen(LinalgNamedStructuredOps.yaml LinalgNamedStructuredOps) # Provide a short name for all external dependency that needs to # include Linalg in ODS add_custom_target(LinalgOdsGen DEPENDS - MLIRLinalgNamedStructuredOpsTcIncGen MLIRLinalgNamedStructuredOpsYamlIncGen ) add_dependencies(mlir-headers LinalgOdsGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ /dev/null @@ -1,7 +0,0 @@ -ods_def -implements_interface : -def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { - // TODO: ideally something closer to - // C(m, n) += cast(A(m, k)) * cast(B(k, n)) - C(m, n) = AddIOp(C(m, n), MulIOp(SignExtendIOp32(A(m, k)), SignExtendIOp32(B(k, n)))); -} diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -671,8 +671,6 @@ // Named Linalg ops, implemented as a declarative configurations of generic ops. //===----------------------------------------------------------------------===// -// This file is auto-generated from a TC def specification. -include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td" include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td" #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2687,7 +2687,6 @@ DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp) DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp) -#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc" #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -66,7 +66,6 @@ mlir-capi-pass-test mlir-capi-sparse-tensor-test mlir-cpu-runner - mlir-linalg-ods-gen mlir-linalg-ods-yaml-gen mlir-lsp-server mlir-opt diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -26,14 +26,14 @@ // CHECK: : tensor to tensor<4x3xi8> // CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor to tensor<2x3xi32> -// CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>) +// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>) // CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: %[[sTD:.*]] = tensor.extract_slice %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xi32> to tensor // CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor into tensor // CHECK: scf.yield %[[TD]] : tensor // CHECK: scf.yield %[[TD2]] : tensor // CHECK: scf.yield %[[TD1]] : tensor - %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile"} + %0 = linalg.matmul {__internal_linalg_transform__ = "tile"} ins(%arg0, %arg1: tensor, tensor) outs(%arg2: tensor) -> tensor @@ -82,19 +82,19 @@ // CHECK-1DIM-TILE: %[[TB:[0-9a-z]+]]: tensor // CHECK-1DIM-TILE: %[[TC:[0-9a-z]+]]: tensor) -> tensor { // CHECK-1DIM-TILE-NOT: scf.for -// CHECK-1DIM-TILE: linalg.matmul_i8_i8_i32 ins(%[[TA]], %[[TB]] : tensor, tensor) outs(%[[TC]] : tensor) -> tensor +// CHECK-1DIM-TILE: linalg.matmul ins(%[[TA]], %[[TB]] : tensor, tensor) outs(%[[TC]] : tensor) -> tensor func @matmul_partially_padded_tensors( %arg0: tensor, %arg1: tensor<8x?xi8>, %arg2: tensor) -> tensor { - %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile"} + %0 = linalg.matmul {__internal_linalg_transform__ = "tile"} ins(%arg0, %arg1: tensor, tensor<8x?xi8>) outs(%arg2: tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @matmul_partially_padded_tensors( -// CHECK: linalg.matmul_i8_i8_i32 ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: linalg.matmul ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32> // Check only the the input operands are padded. @@ -112,7 +112,7 @@ // CHECK-1DIM-TILE: : tensor to tensor<2x8xi8> // CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK-1DIM-TILE: : tensor<8x?xi8> to tensor<8x3xi8> -// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) +// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) // CHECK-1DIM-TILE: outs(%[[sTC]] : tensor) -> tensor // Check that the tile-and-pad transformation actually introduces the padding diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -538,35 +538,6 @@ // ----- -// CHECK-LABEL: func @matmul_i8_i8_i32 -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<4x6xi8> -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<6x12xi8> -// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32> -func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) { - // CHECK-DAG: %[[C0:.*]] = constant 0 : index - // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32> - // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8> - // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<12x6xi8> - // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32> - // CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32> - // CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<12x6xi8> to vector<12x6xi32> - // - // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. - // a later canonicalization fuses the add into vector.contract. - // CHECK: %[[C:.*]] = vector.contract - // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} - // CHECK-SAME: %[[V0_32]], %[[V1_32]], %[[VEC_C0]] - // CHECK-SAME: vector<4x6xi32>, vector<12x6xi32> into vector<4x12xi32> - // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32> - // CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} - // CHECK-SAME: vector<4x12xi32>, memref<4x12xi32> - linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>) - outs(%c: memref<4x12xi32>) - return -} - -// ----- - // CHECK-LABEL: func @pad_static( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x?x2xf32>, %[[PAD:.*]]: f32 // CHECK-NOT: linalg.pad_tensor diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir deleted file mode 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir +++ /dev/null @@ -1,111 +0,0 @@ -// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \ -// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \ -// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16 vectorize" | \ -// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \ -// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \ -// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \ - -// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts -mlir-disable-threading | \ -// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ -// Activate to dump assembly -// R_UN: -dump-object-file -object-filename=/tmp/a.o \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ -// Use tee to both print to stderr and FileCheck -// RUN: tee -a /dev/stderr | FileCheck %s - - -!elem_type_a = type i8 -!elem_type_b = type i8 -!elem_type_c = type i32 -!row_major_A = type memref<${M}x${K}x!elem_type_a> -!row_major_B = type memref<${K}x${N}x!elem_type_b> -!row_major_C = type memref<${M}x${N}x!elem_type_c> - -func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C) -// TODO: activate manually for now. -// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} -{ - linalg.matmul_i8_i8_i32 ins(%a, %b : !row_major_A, !row_major_B) - outs(%c: !row_major_C) - return -} - -func @print_perf(%iters: index, %total_time: f64) { - %c2 = constant 2 : index - %cM = constant ${M} : index - %cN = constant ${N} : index - %cK = constant ${K} : index - - %mn = muli %cM, %cN : index - %mnk = muli %mn, %cK : index - - // 2*M*N*K. - %flops_per_iter = muli %c2, %mnk : index - %flops = muli %iters, %flops_per_iter : index - %flops_i64 = index_cast %flops : index to i64 - %flops_f = sitofp %flops_i64 : i64 to f64 - %flops_per_s = divf %flops_f, %total_time : f64 - vector.print %flops_per_s : f64 - - return -} - -func @main() { - %v0 = constant 0 : !elem_type_c - %v1 = constant 1 : !elem_type_a - - %A = memref.alloc() : !row_major_A - %B = memref.alloc() : !row_major_B - %C = memref.alloc() : !row_major_C - - linalg.fill(%v1, %A) : !elem_type_a, !row_major_A - linalg.fill(%v1, %B) : !elem_type_b, !row_major_B - linalg.fill(%v0, %C) : !elem_type_c, !row_major_C - - %c0 = constant 0: index - %c1 = constant 1: index - %iters = constant 100: index - - /// Run and dump performance for matmul. - /// Preheating run: - scf.for %arg0 = %c0 to %iters step %c1 { - linalg.fill(%v0, %C) : !elem_type_c, !row_major_C - call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () - } - %t_start_matmul = call @rtclock() : () -> f64 - scf.for %arg0 = %c0 to %iters step %c1 { - // linalg.matmul writes %C in place, need to reset it to zero every time. - // This is accounts for about 10-15% perf hit on small sizes. - // Once linalg on tensors is ready, fusing fill at the register level will - // be easy. - linalg.fill(%v0, %C) : !elem_type_c, !row_major_C - call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () - } - %t_end_matmul = call @rtclock() : () -> f64 - %tmatmul = subf %t_end_matmul, %t_start_matmul: f64 - call @print_perf(%iters, %tmatmul) : (index, f64) -> () - - // CHECK: {{^0$}} - %C_ref = memref.alloc() : !row_major_C - linalg.fill(%v0, %C_ref) : !elem_type_c, !row_major_C - linalg.matmul_i8_i8_i32 ins(%A, %B : !row_major_A, !row_major_B) - outs(%C_ref: !row_major_C) - %res = memref.cast %C : !row_major_C to memref<*xi32> - %exp = memref.cast %C_ref : !row_major_C to memref<*xi32> - %errors = call @verifyMemRefI32(%res, %exp) : (memref<*xi32>, memref<*xi32>) -> i64 - vector.print %errors : i64 - memref.dealloc %C_ref : !row_major_C - - memref.dealloc %A : !row_major_A - memref.dealloc %B : !row_major_B - memref.dealloc %C : !row_major_C - - return -} - -func private @rtclock() -> f64 -func private @verifyMemRefI32(memref<*xi32>, memref<*xi32>) -> i64 attributes { llvm.emit_c_interface } - -// TODO: init with random, run and check output. -// func private @fill_random_f32(memref<*xf32>) diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -609,7 +609,6 @@ linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); } tilingPattern.add, - linalg::LinalgTilingPattern, linalg::LinalgTilingPattern>( context, linalgTilingOptions, linalg::LinalgTransformationFilter(Identifier::get("tile", context))); diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -62,7 +62,6 @@ 'mlir-capi-ir-test', 'mlir-capi-pass-test', 'mlir-cpu-runner', - 'mlir-linalg-ods-gen', 'mlir-linalg-ods-yaml-gen', 'mlir-reduce', ] diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc deleted file mode 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ /dev/null @@ -1,209 +0,0 @@ -// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS -// RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL - -// ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [ -// ODS-NEXT: AttrSizedOperandSegments -// ODS-NEXT: DeclareOpInterfaceMethods, -// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> -// -// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() { -// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } -// -// IMPL: ArrayAttr Test1Op::indexing_maps() { -// IMPL: auto s0 = getAffineSymbolExpr(0, context); (void)s0; -// IMPL-NEXT: auto s1 = getAffineSymbolExpr(1, context); (void)s1; -// IMPL-NEXT: auto map0 = AffineMap::get(2, 2, {d0, d1}, context); -// IMPL-NEXT: map0 = map0.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0); -// IMPL-NEXT: map0 = simplifyAffineMap(map0); -// IMPL-NEXT: auto map1 = AffineMap::get(2, 2, {d1}, context); -// IMPL-NEXT: map1 = map1.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0); -// IMPL-NEXT: map1 = simplifyAffineMap(map1); -// IMPL-NEXT: auto map2 = AffineMap::get(2, 2, {d0}, context); -// IMPL-NEXT: map2 = map2.replaceDimsAndSymbols({}, { s0, s1 }, 2, 0); -// IMPL-NEXT: map2 = simplifyAffineMap(map2); -// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); -// -// IMPL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block) { -// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); -// IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); -// IMPL: b.create(ValueRange{ [[e]] }); -// -ods_def : -def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { - C(m) = AddFOp(C(m), MulFOp(A(m, k), B(k))); -} - -// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [ -// ODS-NEXT: AttrSizedOperandSegments -// ODS-NEXT: DeclareOpInterfaceMethods, -// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> -// -// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() { -// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } -// -// IMPL: ArrayAttr Test2Op::indexing_maps() { -// IMPL: AffineMap::get(3, 3, {d0, d2}, context) -// IMPL: AffineMap::get(3, 3, {d2, d1}, context) -// IMPL: AffineMap::get(3, 3, {d0, d1}, context) -// -// IMPL: Test2Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block) { -// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); -// IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); -// IMPL: b.create(ValueRange{ [[e]] }); -// -ods_def : -def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { - C(m, n) = AddFOp(C(m, n), MulFOp(A(m, k), B(k, n))); -} - -// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [ -// ODS-NEXT: AttrSizedOperandSegments -// ODS-NEXT: DeclareOpInterfaceMethods, -// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> -// -// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() { -// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } -// -// IMPL: ArrayAttr Test3Op::indexing_maps() { -// IMPL: AffineMap::get(4, 4, {d0, d1, d3}, context) -// IMPL: AffineMap::get(4, 4, {d3, d2}, context) -// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) -// -// IMPL: Test3Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block) { -// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); -// IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); -// IMPL: b.create(ValueRange{ [[e]] }); -// -ods_def : -def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { - C(b, m, n) = AddFOp(C(b, m, n), MulFOp(A(b, m, k), B(k, n))); -} - -// Test attribute definitions -// ODS-LABEL: def Test4Op -// ODS: F32ArrayAttr:$array_attr, -// ODS: F32Attr:$f32_attr, -// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr, -// ODS: I32Attr:$i32_attr, -// ODS: I64Attr:$i64_attr, -// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr, -// ODS: OptionalAttr:$optional_attr -// -// ODS: bool hasDynamicIndexingMaps(); -// ODS: LogicalResult verifyIndexingMapRequiredAttributes(); -// -// IMPL: bool Test4Op::hasDynamicIndexingMaps() { return true; } -// IMPL: LogicalResult Test4Op::verifyIndexingMapRequiredAttributes() -// IMPL: op->getAttrOfType("array_attr") -// IMPL: op->getAttr("f32_attr") -// IMPL: op->getAttrOfType("fvec_attr") -// IMPL: op->getAttr("i32_attr") -// IMPL: op->getAttr("i64_attr") -// IMPL: op->getAttrOfType("ivec_attr") -// -ods_def : -def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) -attr( - f32_attr: f32, - i32_attr: i32, - i64_attr: i64, - fvec_attr: 4xf32, - ivec_attr: 5x6xi32, - array_attr : f32[], - optional_attr? : f32 -) { - C(b, m, n) = AddFOp(C(b, m, n), MulFOp(A(b, m, k), B(k, n))); -} - -// Test attribute usage in affine expressions -// IMPL-LABEL: ArrayAttr Test5Op::indexing_maps() { -// IMPL: auto cst0 = getAffineConstantExpr(strides().getValue({ 0 }), context); -// IMPL: auto cst1 = getAffineConstantExpr(strides().getValue({ 1 }), context); -// IMPL: auto map0 = AffineMap::get(7, 9, {d0, d1 * s7 + d4, d2 * s8 + d5, d6}, context); -// IMPL: map0 = map0.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); -// IMPL: map0 = simplifyAffineMap(map0); -// IMPL: auto map1 = AffineMap::get(7, 9, {d3, d4, d5, d6}, context); -// IMPL: map1 = map1.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); -// IMPL: map1 = simplifyAffineMap(map1); -// IMPL: auto map2 = AffineMap::get(7, 7, {d0, d1, d2, d3}, context); -// IMPL: map2 = map2.replaceDimsAndSymbols({}, { s0, s1, s2, s3, s4, s5, s6, cst0, cst1 }, 7, 0); -// IMPL: map2 = simplifyAffineMap(map2); -// IMPL: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); -// -ods_def: -def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) - attr(strides: 2xi32) { - O(n, h, w, f) = AddFOp( - MulFOp(AddFOp(I(n, h * strides[0] + kh, w * strides[1] + kw, c), - I(n, h * strides[0] + kh, w * strides[1] + kw, c)), - K(f, kh, kw, c))); -} - -// Test documentation -// ODS-LABEL: def Test6Op -// ODS: let summary = [{ My magic op. }]; -// ODS-NEXT: let description = [{ -// ODS-NEXT: It has two inputs. -// ODS-NEXT: It has one output. -// ODS-NEXT: }]; -// -ods_def: -def test6(A: f32(M, K), B: f32(K)) -> (C: f32(M)) -""" -My magic op. - -It has two inputs. -It has one output. -""" -{ - C(m) = AddFOp(C(m), MulFOp(A(m, k), B(k))); -} - -// Test attribute builder -// ODS-LABEL: def Test7Op -// ODS: OpBuilder< -// ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, -// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b, -// ODS: CArg<"ArrayRef", "{}">:$attributes) -// ODS: $_state.addAttribute("attr_a", attr_a); -// ODS: $_state.addAttribute("attr_b", attr_b); -// -ods_def: -def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M)) - attr(attr_a: f32, attr_b: 4xi32) -{ - C(m) = AddFOp(C(m), MulFOp(A(m, k), B(k))); -} - -// Test output arg order. -// IMPL-LABEL: void Test8Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block) { -// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); -// IMPL: Value [[e:.*]] = b.create([[d]], [[c]]); -// IMPL: b.create(ValueRange{ [[e]] }); -ods_def: -def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M)) -{ - C(m) = SubFOp(MulFOp(A(m, k), B(k)), C(m)); -} - -// Test shape-only operand. -// IMPL-LABEL: ArrayAttr Test9Op::indexing_maps() { -// IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context); -// IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context); -// IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context); -// IMPL-LABEL: void Test9Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block) { -// IMPL: Value [[a:.*]](args[0]), [[c:.*]](args[2]); -ods_def: -def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M)) -{ - C(m) = AddFOp(C(m), A(m, k)); -} diff --git a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt --- a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt +++ b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt @@ -3,42 +3,6 @@ Support ) -set(LLVM_OPTIONAL_SOURCES - mlir-linalg-ods-gen.cpp - mlir-linalg-ods-yaml-gen.cpp -) - -# Original mlir-linalg-ods-gen (to be replaced). -add_llvm_tool(mlir-linalg-ods-gen - mlir-linalg-ods-gen.cpp -) -llvm_update_compile_flags(mlir-linalg-ods-gen) -target_link_libraries(mlir-linalg-ods-gen PRIVATE - MLIRSupport - MLIRIR - ) - -set(MLIR_LINALG_ODS_GEN mlir-linalg-ods-gen CACHE - STRING "Native mlir-linalg-ods-gen executable. Saves building one when cross-compiling.") - -set(MLIR_LINALG_ODS_GEN_EXE ${MLIR_LINALG_ODS_GEN} PARENT_SCOPE) -set(MLIR_LINALG_ODS_GEN_TARGET mlir-linalg-ods-gen PARENT_SCOPE) - -if(LLVM_USE_HOST_TOOLS) - if (${MLIR_LINALG_ODS_GEN} STREQUAL "mlir-linalg-ods-gen") - build_native_tool(mlir-linalg-ods-gen MLIR_LINALG_ODS_GEN_EXE DEPENDS mlir-linalg-ods-gen) - set(MLIR_LINALG_ODS_GEN_EXE ${MLIR_LINALG_ODS_GEN_EXE} PARENT_SCOPE) - - add_custom_target(mlir-linalg-ods-gen-host DEPENDS ${MLIR_LINALG_ODS_GEN_EXE}) - set(MLIR_LINALG_ODS_GEN_TARGET mlir-linalg-ods-gen-host DEPENDS PARENT_SCOPE) - - if(NOT LLVM_BUILD_UTILS) - set_target_properties(mlir-linalg-ods-gen PROPERTIES EXCLUDE_FROM_ALL ON) - endif() - endif() -endif() - - # New mlir-linalg-ods-yaml-gen. add_llvm_tool(mlir-linalg-ods-yaml-gen mlir-linalg-ods-yaml-gen.cpp diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp deleted file mode 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ /dev/null @@ -1,2472 +0,0 @@ -//===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains the implementation for the Tensor Comprehension-inspired -// parser and ODS pretty-printer for specifying Linalg "named ops" from a -// mathematical form. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/Support/FileUtilities.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" - -#include -#include - -#define DEBUG_TYPE "linalg-ods-gen" - -static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen"); - -// Commandline options -static llvm::cl::opt - inputFilename(llvm::cl::Positional, llvm::cl::desc(""), - llvm::cl::init("-"), llvm::cl::value_desc("filename")); - -static llvm::cl::opt - outputFilename("o", llvm::cl::desc("Output filename"), - llvm::cl::value_desc("filename"), llvm::cl::init("-")); - -static llvm::cl::opt - genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."), - llvm::cl::cat(ODSGenCat)); - -static llvm::cl::opt - genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"), - llvm::cl::init(false), llvm::cl::cat(ODSGenCat)); - -static llvm::cl::opt testEmitIncludeTdHeader( - "test-emit-include-td-header", - llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end " - "tblgen testing."), - llvm::cl::init(false), llvm::cl::cat(ODSGenCat)); - -using llvm::SMLoc; -using llvm::StringRef; -using llvm::Twine; - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// Special "op aliases" substitutions. -//===----------------------------------------------------------------------===// - -/// Perform substitutions of known special ops. -/// This is a poor man's way of achieving "op aliases": i.e. giving an op a -/// name. -/// This is hacky and temporary until migration to the python opdsl is complete. -static void substituteOpAliases(std::string &expressionsStr) { - for (auto kvp : SmallVector>{ - {"b.create(", "b.create(CmpIPredicate::sgt, "}, - {"b.create(", "b.create(CmpFPredicate::OGT, "}, - {"b.create(", "b.create(CmpFPredicate::OLT, "}, - {"b.create(", - "b.create(b.getI32Type(), "}, - }) { - size_t pos = 0; - while ((pos = expressionsStr.find(kvp.first, pos)) != std::string::npos) { - expressionsStr.replace(pos, kvp.first.size(), kvp.second); - pos += kvp.second.size(); - } - } -} - -//===----------------------------------------------------------------------===// -// Lexer -//===----------------------------------------------------------------------===// - -namespace { -/// This class represents a specific token in the input format. -class Token { -public: - enum class Kind { - // Markers. - eof, - error, - - // Tokens with no info. - colon, - comma, - doc_str, - equal, - gt, - l_brace, - l_paren, - l_square, - lt, - minus, - plus, - question, - r_brace, - r_paren, - r_square, - semicolon, - star, - - // Keywords. - kw_def, - FIRST_KEYWORD = kw_def, - kw_ods_def, - kw_implements_interface, - kw_attr_def, - kw_floordiv, - kw_ceildiv, - kw_mod, - LAST_KEYWORD = kw_mod, - - // String valued tokens. - id, - integer, - }; - - Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} - - /// Return the bytes that make up this token. - StringRef getSpelling() const { return spelling; } - - /// Return the kind of this token. - Kind getKind() const { return kind; } - - /// Return a location for this token. - llvm::SMLoc getLoc() const { - return llvm::SMLoc::getFromPointer(spelling.data()); - } - - /// Return if this token is a keyword. - bool isKeyword() const { - return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD; - } - bool is(Kind k) const { return kind == k; } - bool isNot(Kind k) const { return kind != k; } - - Optional getUInt64IntegerValue() const { - bool isHex = spelling.size() > 1 && spelling[1] == 'x'; - - uint64_t result = 0; - if (spelling.getAsInteger(isHex ? 0 : 10, result)) - return None; - return result; - } - -private: - /// Discriminator that indicates the kind of token this is. - Kind kind; - - /// A reference to the entire token contents; this is always a pointer into - /// a memory buffer owned by the source manager. - StringRef spelling; -}; - -/// This class implements a simple lexer. -class Lexer { -public: - Lexer(llvm::SourceMgr &mgr); - - /// Lex the next token and return it. - Token lexToken(); - - /// Emit an error to the lexer with the given location and message. - Token emitError(llvm::SMLoc loc, const Twine &msg); - Token emitError(const char *loc, const Twine &msg); - - /// Change the position of the lexer cursor. The next token we lex will start - /// at the designated point in the input. - void resetPointer(const char *newPtr) { curPtr = newPtr; } - -private: - Token formToken(Token::Kind kind, const char *tokStart) { - return Token(kind, StringRef(tokStart, curPtr - tokStart)); - } - - /// Return the next character in the stream. - int getNextChar(); - - /// Lex an identifier. - Token lexIdentifier(const char *tokStart); - - // Lex an integer. - Token lexInteger(const char *tokStart); - - // Lex a string. - Token lexString(const char *tokStart); - - // Skip a comment line, starting with a '//'. - void skipComment(); - - llvm::SourceMgr &srcMgr; - StringRef curBuffer; - const char *curPtr; -}; -} // end anonymous namespace - -Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) { - curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer(); - curPtr = curBuffer.begin(); -} - -Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) { - srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); - return formToken(Token::Kind::error, loc.getPointer()); -} -Token Lexer::emitError(const char *loc, const Twine &msg) { - return emitError(llvm::SMLoc::getFromPointer(loc), msg); -} - -int Lexer::getNextChar() { - char curChar = *curPtr++; - switch (curChar) { - default: - return (unsigned char)curChar; - case 0: { - // A nul character in the stream is either the end of the current buffer - // or a random nul in the file. Disambiguate that here. - if (curPtr - 1 != curBuffer.end()) - return 0; - - // Otherwise, return end of file. - --curPtr; - return EOF; - } - case '\n': - case '\r': - // Handle the newline character by ignoring it and incrementing the line - // count. However, be careful about 'dos style' files with \n\r in them. - // Only treat a \n\r or \r\n as a single line. - if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) - ++curPtr; - return '\n'; - } -} - -Token Lexer::lexToken() { - while (true) { - const char *tokStart = curPtr; - - // This always consumes at least one character. - int curChar = getNextChar(); - switch (curChar) { - default: - // Handle identifiers: [a-zA-Z_] - if (isalpha(curChar) || curChar == '_') - return lexIdentifier(tokStart); - - // Handle integers: [0-9] - if (isdigit(curChar)) - return lexInteger(tokStart); - - // Unknown character, emit an error. - return emitError(tokStart, "unexpected character"); - - case EOF: - // Return EOF denoting the end of lexing. - return formToken(Token::Kind::eof, tokStart); - - // Lex punctuation. - case ':': - return formToken(Token::Kind::colon, tokStart); - case ',': - return formToken(Token::Kind::comma, tokStart); - case '=': - return formToken(Token::Kind::equal, tokStart); - case '{': - return formToken(Token::Kind::l_brace, tokStart); - case '(': - return formToken(Token::Kind::l_paren, tokStart); - case '[': - return formToken(Token::Kind::l_square, tokStart); - case '}': - return formToken(Token::Kind::r_brace, tokStart); - case ')': - return formToken(Token::Kind::r_paren, tokStart); - case ']': - return formToken(Token::Kind::r_square, tokStart); - case '<': - return formToken(Token::Kind::lt, tokStart); - case '>': - return formToken(Token::Kind::gt, tokStart); - case '+': - return formToken(Token::Kind::plus, tokStart); - case '-': - return formToken(Token::Kind::minus, tokStart); - case ';': - return formToken(Token::Kind::semicolon, tokStart); - case '*': - return formToken(Token::Kind::star, tokStart); - case '?': - return formToken(Token::Kind::question, tokStart); - case '"': - return lexString(tokStart); - case '/': - if (*curPtr == '/') { - skipComment(); - continue; - } - // Unknown character, emit an error. - return emitError(tokStart, "unexpected character: not a comment"); - - // Ignore whitespace characters. - case 0: - case ' ': - case '\t': - case '\n': - return lexToken(); - } - } -} - -Token Lexer::lexIdentifier(const char *tokStart) { - // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* - while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') - ++curPtr; - - // Check to see if this identifier is a keyword. - StringRef str(tokStart, curPtr - tokStart); - Token::Kind kind = - StringSwitch(str) - .Case("attr", Token::Kind::kw_attr_def) - .Case("def", Token::Kind::kw_def) - .Case("ods_def", Token::Kind::kw_ods_def) - .Case("implements_interface", Token::Kind::kw_implements_interface) - .Case("floordiv", Token::Kind::kw_floordiv) - .Case("ceildiv", Token::Kind::kw_ceildiv) - .Case("mod", Token::Kind::kw_mod) - .Default(Token::Kind::id); - - return Token(kind, str); -} - -Token Lexer::lexInteger(const char *tokStart) { - // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* - while (isdigit(*curPtr)) - ++curPtr; - - StringRef str(tokStart, curPtr - tokStart); - return Token(Token::Kind::integer, str); -} - -Token Lexer::lexString(const char *tokStart) { - assert(curPtr[-1] == '"'); - - if (*curPtr == '"' && *(curPtr + 1) == '"') { - curPtr += 2; - while (true) { - switch (*curPtr++) { - case '"': - if (*curPtr == '"' && *(curPtr + 1) == '"') { - Token token(Token::Kind::doc_str, - StringRef(tokStart + 3, curPtr - tokStart - 4)); - curPtr += 2; - return token; - } - continue; - case 0: - // If this is a random nul character in the middle of the doc string, - // just include it. If it is the end of file, then it is an error. - if (curPtr - 1 != curBuffer.end()) - continue; - return emitError(curPtr - 1, "expected '\"\"\"' to end doc string"); - default: - continue; - } - } - } - - return emitError(curPtr - 1, "expected '\"\"\"' to start doc string"); -} - -/// Skip a comment line, starting with a '//'. -void Lexer::skipComment() { - // Advance over the second '/' in a '//' comment. - assert(*curPtr == '/'); - ++curPtr; - - while (true) { - switch (*curPtr++) { - case '\n': - case '\r': - // Newline is end of comment. - return; - case 0: - // If this is the end of the buffer, end the comment. - if (curPtr - 1 == curBuffer.end()) { - --curPtr; - return; - } - LLVM_FALLTHROUGH; - default: - // Skip over other characters. - break; - } - } -} - -namespace { - -class Parser { -public: - Parser(llvm::SourceMgr &mgr, MLIRContext *ctx) - : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {} - - //===--------------------------------------------------------------------===// - // Lexer Utilities - //===--------------------------------------------------------------------===// - - LogicalResult parseInteger(uint64_t &value) { - if (!curToken.is(Token::Kind::integer)) - return emitError(curToken.getLoc(), "expected integer"); - value = curToken.getUInt64IntegerValue().getValue(); - consumeToken(); - return success(); - } - - /// Advance the current lexer onto the next token. - void consumeToken() { - assert(curToken.getKind() != Token::Kind::eof && - curToken.getKind() != Token::Kind::error && - "shouldn't advance past EOF or errors"); - curToken = lexer.lexToken(); - } - - void consumeToken(Token::Kind kind) { - assert(curToken.getKind() == kind && "unexpected token"); - curToken = lexer.lexToken(); - } - - LogicalResult parseToken(Token::Kind kind, const Twine &msg) { - if (curToken.getKind() != kind) - return emitError(curToken.getLoc(), msg); - consumeToken(); - return success(); - } - - /// Parses an optional token and returns failure if failed to parse. - LogicalResult parseOptionalToken(Token::Kind kind) { - return success(consumeIf(kind)); - } - - LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { - lexer.emitError(loc, msg); - return failure(); - } - - LogicalResult emitError(const Twine &msg) { - return emitError(curToken.getLoc(), msg); - } - - bool consumeIf(Token::Kind kind) { - if (curToken.isNot(kind)) - return false; - consumeToken(kind); - return true; - } - - LogicalResult - parseCommaSeparatedList(llvm::function_ref parseElement) { - // Non-empty case starts with an element. - if (parseElement()) - return failure(); - - // Otherwise we have a list of comma separated elements. - while (consumeIf(Token::Kind::comma)) { - if (parseElement()) - return failure(); - } - return success(); - } - - LogicalResult - parseCommaSeparatedListUntil(Token::Kind rightToken, - llvm::function_ref parseElement, - bool allowEmptyList) { - // Handle the empty case. - if (curToken.is(rightToken)) { - if (!allowEmptyList) - return emitError("expected list element"); - consumeToken(rightToken); - return success(); - } - - if (failed(parseCommaSeparatedList(parseElement)) || - failed( - parseToken(rightToken, "expected ',' or right-terminating token"))) - return failure(); - - return success(); - } - - Lexer lexer; - Token curToken; - MLIRContext *context; -}; -} // namespace - -/// Encodes an attribute use of the form: -/// -/// index-list ::= integer-literal (`,` integer-literal)* -/// attr-use ::= bare-id `[` index-list `]` -struct AttrUse { - // Referenced attribute - StringRef attrName; - // Indices into the attribute - SmallVector indices; - /// Affine symbol for this usage. - /// This is represented as an affine symbol because at the time of parsing the - /// spec and generating the op's ODS/C++, we don't know the concrete constant - /// value. But they should be replaced with constants read from the attribute - /// and thus folded away for concrete op instances. - AffineExpr symbol; - - std::string getKey() { - SmallVector indexStrs; - for (uint64_t index : indices) - indexStrs.push_back(std::to_string(index)); - return llvm::formatv("{0}[{1}]", attrName, llvm::join(indexStrs, ",")); - } -}; - -//===----------------------------------------------------------------------===// -// Affine parsing. -//===----------------------------------------------------------------------===// - -namespace { - -/// Lower precedence ops (all at the same precedence level). LNoOp is false in -/// the boolean sense. -enum AffineLowPrecOp { - /// Null value. - LNoOp, - Add, - Sub -}; - -/// Higher precedence ops - all at the same precedence level. HNoOp is false -/// in the boolean sense. -enum AffineHighPrecOp { - /// Null value. - HNoOp, - Mul, - FloorDiv, - CeilDiv, - Mod -}; - -using AffineDimList = SmallVector, 4>; -using AffineSymbolList = SmallVector, 4>; - -/// This is a specialized parser for affine expressions. -class AffineParser { -public: - /// Creates an affine parser that parses tokens from `p`. - /// - /// The affine parser introduces new dimensions and symbols eagerly as new - /// `id` are discovered. To additionally support attribute use `id`s, for a - /// parsed `id`, the resolution mechanism proceeds as follows: - /// 1. Try to parse `id` as an attribute use (using the `attrUseParsingHook`). - /// 2. If unsuccessful, try to match `id` to a known dim or symbol. - /// 3. If still unsuccessful, eagerly create a new dim or symbol and add it to - /// the known dims or symbols (using the `bareIdParsingHook`). - explicit AffineParser( - Parser &p, std::function bareIdParsingHook, - std::function()> attrUseParsingHook, - AffineDimList &dimList, AffineSymbolList &symbolList) - : parser(p), bareIdFallback(bareIdParsingHook), - attrUseCallback(attrUseParsingHook), dims(dimList), - symbols(symbolList) {} - - /// Parse a comma-separated list of affine exprs. - SmallVector - parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren, - Token::Kind rDelim = Token::Kind::r_paren); - - /// Parse a single affine expr.`. - AffineExpr parseAffineExpr(); - -private: - // Binary affine op parsing. - AffineLowPrecOp consumeIfLowPrecOp(); - AffineHighPrecOp consumeIfHighPrecOp(); - - // AffineExpr parsing. - AffineExpr parseParentheticalExpr(); - AffineExpr parseNegateExpression(AffineExpr lhs); - AffineExpr parseIntegerExpr(); - AffineExpr parseAttrUseOrBareIdExpr(); - AffineExpr parseBareIdExpr(); - - AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, - AffineExpr rhs, SMLoc opLoc); - AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, - AffineExpr rhs); - AffineExpr parseAffineOperandExpr(AffineExpr lhs); - AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); - AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, - SMLoc llhsOpLoc); - - Parser &parser; - std::function bareIdFallback; - std::function()> attrUseCallback; - AffineDimList &dims; - AffineSymbolList &symbols; -}; -} // end anonymous namespace - -/// Create an affine binary high precedence op expression (mul's, div's, mod). -/// opLoc is the location of the op token to be used to report errors -/// for non-conforming expressions. -AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, - AffineExpr lhs, AffineExpr rhs, - SMLoc opLoc) { - switch (op) { - case Mul: - if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { - (void)parser.emitError( - opLoc, "non-affine expression: at least one of the multiply " - "operands has to be either a constant or symbolic"); - return nullptr; - } - return lhs * rhs; - case FloorDiv: - if (!rhs.isSymbolicOrConstant()) { - (void)parser.emitError(opLoc, - "non-affine expression: right operand of floordiv " - "has to be either a constant or symbolic"); - return nullptr; - } - return lhs.floorDiv(rhs); - case CeilDiv: - if (!rhs.isSymbolicOrConstant()) { - (void)parser.emitError(opLoc, - "non-affine expression: right operand of ceildiv " - "has to be either a constant or symbolic"); - return nullptr; - } - return lhs.ceilDiv(rhs); - case Mod: - if (!rhs.isSymbolicOrConstant()) { - (void)parser.emitError(opLoc, - "non-affine expression: right operand of mod " - "has to be either a constant or symbolic"); - return nullptr; - } - return lhs % rhs; - case HNoOp: - llvm_unreachable("can't create affine expression for null high prec op"); - return nullptr; - } - llvm_unreachable("Unknown AffineHighPrecOp"); -} - -/// Create an affine binary low precedence op expression (add, sub). -AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, - AffineExpr lhs, AffineExpr rhs) { - switch (op) { - case AffineLowPrecOp::Add: - return lhs + rhs; - case AffineLowPrecOp::Sub: - return lhs - rhs; - case AffineLowPrecOp::LNoOp: - llvm_unreachable("can't create affine expression for null low prec op"); - return nullptr; - } - llvm_unreachable("Unknown AffineLowPrecOp"); -} - -/// Consume this token if it is a lower precedence affine op (there are only -/// two precedence levels). -AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { - switch (parser.curToken.getKind()) { - case Token::Kind::plus: - parser.consumeToken(); - return AffineLowPrecOp::Add; - case Token::Kind::minus: - parser.consumeToken(); - return AffineLowPrecOp::Sub; - default: - return AffineLowPrecOp::LNoOp; - } -} - -/// Consume this token if it is a higher precedence affine op (there are only -/// two precedence levels) -AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { - switch (parser.curToken.getKind()) { - case Token::Kind::star: - parser.consumeToken(Token::Kind::star); - return Mul; - case Token::Kind::kw_floordiv: - parser.consumeToken(Token::Kind::kw_floordiv); - return FloorDiv; - case Token::Kind::kw_ceildiv: - parser.consumeToken(Token::Kind::kw_ceildiv); - return CeilDiv; - case Token::Kind::kw_mod: - parser.consumeToken(Token::Kind::kw_mod); - return Mod; - default: - return HNoOp; - } -} - -/// Parse a high precedence op expression list: mul, div, and mod are high -/// precedence binary ops, i.e., parse a -/// expr_1 op_1 expr_2 op_2 ... expr_n -/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). -/// All affine binary ops are left associative. -/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is -/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is -/// null. llhsOpLoc is the location of the llhsOp token that will be used to -/// report an error for non-conforming expressions. -AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, - AffineHighPrecOp llhsOp, - SMLoc llhsOpLoc) { - AffineExpr lhs = parseAffineOperandExpr(llhs); - if (!lhs) - return nullptr; - - // Found an LHS. Parse the remaining expression. - auto opLoc = parser.curToken.getLoc(); - if (AffineHighPrecOp op = consumeIfHighPrecOp()) { - if (llhs) { - AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); - if (!expr) - return nullptr; - return parseAffineHighPrecOpExpr(expr, op, opLoc); - } - // No LLHS, get RHS - return parseAffineHighPrecOpExpr(lhs, op, opLoc); - } - - // This is the last operand in this expression. - if (llhs) - return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); - - // No llhs, 'lhs' itself is the expression. - return lhs; -} - -/// Parse an affine expression inside parentheses. -/// -/// affine-expr ::= `(` affine-expr `)` -AffineExpr AffineParser::parseParentheticalExpr() { - if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) - return nullptr; - if (parser.curToken.is(Token::Kind::r_paren)) - return ((void)parser.emitError("no expression inside parentheses"), - nullptr); - - auto expr = parseAffineExpr(); - if (!expr) - return nullptr; - if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'"))) - return nullptr; - - return expr; -} - -/// Parse the negation expression. -/// -/// affine-expr ::= `-` affine-expr -AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { - if (failed(parser.parseToken(Token::Kind::minus, "expected '-'"))) - return nullptr; - - AffineExpr operand = parseAffineOperandExpr(lhs); - // Since negation has the highest precedence of all ops (including high - // precedence ops) but lower than parentheses, we are only going to use - // parseAffineOperandExpr instead of parseAffineExpr here. - if (!operand) - // Extra error message although parseAffineOperandExpr would have - // complained. Leads to a better diagnostic. - return ((void)parser.emitError("missing operand of negation"), nullptr); - return (-1) * operand; -} - -AffineExpr AffineParser::parseAttrUseOrBareIdExpr() { - if (llvm::Optional attrUse = attrUseCallback()) - return attrUse.getValue(); - return parseBareIdExpr(); -} - -/// Parse a bare id that may appear in an affine expression. -/// -/// affine-expr ::= bare-id -AffineExpr AffineParser::parseBareIdExpr() { - if (parser.curToken.isNot(Token::Kind::id)) - return ((void)parser.emitError("expected id"), nullptr); - - StringRef sRef = parser.curToken.getSpelling(); - for (auto &list : {dims, symbols}) { - for (auto entry : list) { - if (entry.first == sRef) { - parser.consumeToken(Token::Kind::id); - return entry.second; - } - } - } - - // Not found, check fallback path. - AffineExpr expr = bareIdFallback(sRef); - if (expr) { - parser.consumeToken(Token::Kind::id); - return expr; - } - - return ((void)parser.emitError("use of undeclared id"), nullptr); -} - -/// Parse a positive integral constant appearing in an affine expression. -/// -/// affine-expr ::= integer-literal -AffineExpr AffineParser::parseIntegerExpr() { - auto val = parser.curToken.getUInt64IntegerValue(); - if (!val.hasValue() || (int64_t)val.getValue() < 0) - return ((void)parser.emitError("constant too large for index"), nullptr); - - parser.consumeToken(Token::Kind::integer); - return getAffineConstantExpr((int64_t)val.getValue(), parser.context); -} - -/// Parses an expression that can be a valid operand of an affine expression. -/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary -/// operator, the rhs of which is being parsed. This is used to determine -/// whether an error should be emitted for a missing right operand. -// Eg: for an expression without parentheses (like i + j + k + l), each -// of the four identifiers is an operand. For i + j*k + l, j*k is not an -// operand expression, it's an op expression and will be parsed via -// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -// -l are valid operands that will be parsed by this function. -AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { - switch (parser.curToken.getKind()) { - case Token::Kind::id: - return parseAttrUseOrBareIdExpr(); - case Token::Kind::integer: - return parseIntegerExpr(); - case Token::Kind::l_paren: - return parseParentheticalExpr(); - case Token::Kind::minus: - return parseNegateExpression(lhs); - case Token::Kind::kw_ceildiv: - case Token::Kind::kw_floordiv: - case Token::Kind::kw_mod: - case Token::Kind::plus: - case Token::Kind::star: - if (lhs) - (void)parser.emitError("missing right operand of binary operator"); - else - (void)parser.emitError("missing left operand of binary operator"); - return nullptr; - default: - if (lhs) - (void)parser.emitError("missing right operand of binary operator"); - else - (void)parser.emitError("expected affine expression"); - return nullptr; - } -} - -/// Parse affine expressions that are bare-id's, integer constants, -/// parenthetical affine expressions, and affine op expressions that are a -/// composition of those. -/// -/// All binary op's associate from left to right. -/// -/// {add, sub} have lower precedence than {mul, div, and mod}. -/// -/// Add, sub'are themselves at the same precedence level. Mul, floordiv, -/// ceildiv, and mod are at the same higher precedence level. Negation has -/// higher precedence than any binary op. -/// -/// llhs: the affine expression appearing on the left of the one being parsed. -/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, -/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned -/// if llhs is non-null; otherwise lhs is returned. This is to deal with left -/// associativity. -/// -/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function -/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where -/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). -AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, - AffineLowPrecOp llhsOp) { - AffineExpr lhs; - if (!(lhs = parseAffineOperandExpr(llhs))) - return nullptr; - - // Found an LHS. Deal with the ops. - if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { - if (llhs) { - AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); - return parseAffineLowPrecOpExpr(sum, lOp); - } - // No LLHS, get RHS and form the expression. - return parseAffineLowPrecOpExpr(lhs, lOp); - } - auto opLoc = parser.curToken.getLoc(); - if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { - // We have a higher precedence op here. Get the rhs operand for the llhs - // through parseAffineHighPrecOpExpr. - AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); - if (!highRes) - return nullptr; - - // If llhs is null, the product forms the first operand of the yet to be - // found expression. If non-null, the op to associate with llhs is llhsOp. - AffineExpr expr = - llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; - - // Recurse for subsequent low prec op's after the affine high prec op - // expression. - if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) - return parseAffineLowPrecOpExpr(expr, nextOp); - return expr; - } - // Last operand in the expression list. - if (llhs) - return getAffineBinaryOpExpr(llhsOp, llhs, lhs); - // No llhs, 'lhs' itself is the expression. - return lhs; -} - -/// Parse an affine expression. -/// affine-expr ::= `(` affine-expr `)` -/// | `-` affine-expr -/// | affine-expr `+` affine-expr -/// | affine-expr `-` affine-expr -/// | affine-expr `*` affine-expr -/// | affine-expr `floordiv` affine-expr -/// | affine-expr `ceildiv` affine-expr -/// | affine-expr `mod` affine-expr -/// | bare-id -/// | integer-literal -/// -/// Additional conditions are checked depending on the production. For eg., -/// one of the operands for `*` has to be either constant/symbolic; the second -/// operand for floordiv, ceildiv, and mod has to be a positive integer. -AffineExpr AffineParser::parseAffineExpr() { - return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); -} - -SmallVector AffineParser::parseAffineExprs(Token::Kind lDelim, - Token::Kind rDelim) { - if (failed(parser.parseToken(lDelim, - "expected lDelim at start of affine expr list"))) - return {}; - - SmallVector exprs; - auto parseElt = [&]() -> LogicalResult { - auto elt = parseAffineExpr(); - exprs.push_back(elt); - return elt ? success() : failure(); - }; - - if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt, - /*allowEmptyList=*/true))) - llvm_unreachable("Failed AffineExpr parsing"); - - return exprs; -} - -//===----------------------------------------------------------------------===// -// TC parsing. -//===----------------------------------------------------------------------===// - -namespace { - -/// Base class for expressions involved in TC parsing. -struct Expression { - enum class Kind { - Uninitialized = 0, - TensorExpr = 1, - TensorUse = 2, - }; - - explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {} - virtual ~Expression() = default; - - operator bool() const { return kind != Kind::Uninitialized; } - - Kind kind; -}; - -/// Encodes a tensor use of the form: -/// -/// affine-expr-list ::= affine-expr (`,` affine-expr)* -/// tensor-use ::= bare-id `(` `)` -/// | bare-id `(` affine-expr-list `)` -/// -/// The affine-expr-list is stored as an AffineMap. -struct TensorUse : public Expression { - TensorUse() : TensorUse("", AffineMap()) {} - TensorUse(StringRef name, AffineMap map) - : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {} - - static bool classof(const Expression *e) { - return e->kind == Kind::TensorUse; - } - - bool operator==(const TensorUse &other) const { - return tensorId == other.tensorId && indexingMap == other.indexingMap; - } - - /// Visitation function. Performs preorder or postorder traversal depending on - /// `PreOrder` and applies `callback` on each node. - template void visit(Lambda callback) const; - - StringRef tensorId; - AffineMap indexingMap; -}; - -/// Encodes a tensor expression of the form: -/// -/// op-spec ::= bare-id `<` reduction-dims-list `>` -/// | bare-id -/// op-arg ::= tensor-expr -/// | tensor-use -/// op-arg-list ::= op-arg (`,` op-arg)* -/// tensor-expr ::= op-spec `(` op-arg-list `)` -/// -/// Underlying op-arg are stored by unique_ptr to base class. -struct TensorExpr : public Expression { - TensorExpr(StringRef name, - SmallVectorImpl> &&exprs, - ArrayRef reductionDims) - : Expression(Kind::TensorExpr), operationName(name), - expressions(std::move(exprs)), - reductionDimensions(reductionDims.begin(), reductionDims.end()) {} - - static bool classof(const Expression *e) { - return e->kind == Kind::TensorExpr; - } - - bool operator==(const TensorExpr &other) const { - if (operationName != other.operationName) - return false; - if (expressions.size() != other.expressions.size()) - return false; - for (unsigned i = 0, e = expressions.size(); i < e; ++i) - if (*expressions[i] != *other.expressions[i]) - return false; - for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i) - if (reductionDimensions[i] != other.reductionDimensions[i]) - return false; - return true; - } - - /// Visitation function. Performs preorder or postorder traversal depending on - /// `PreOrder` and applies `callback` on each node. - template void visit(Lambda callback) const; - - StringRef operationName; - SmallVector, 4> expressions; - SetVector reductionDimensions; -}; - -/// This is a specialized parser for a TCDef. -/// This maintains the dims it finds in an eager fashion. -class TCParser { - enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions }; - -public: - explicit TCParser(Parser &p); - - /// Uses the AffineParser to parse the affine exprs used in a tensor - /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new - /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is - /// emitted on new identifiers. - SmallVector - parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims, - Token::Kind lDelim = Token::Kind::l_paren, - Token::Kind rDelim = Token::Kind::r_paren); - - /// Parse the information for a tensor def. - /// All the affine-expr must be dimensionless (i.e. contain only expressions - /// involving symbols and constants), but can otherwise contain arbitrary - /// affine expressions. - LogicalResult parseTensorDef(bool isOutput); - - /// Parses a tensor use. - struct ComprehensionParsingState { - /// The number of operands (which includes inputs and outputs) in a - /// comprehension. - size_t numArgs; - AffineDimList dims; - SmallVector, 4> expressions; - llvm::DenseMap orderedTensorArgs; - }; - LogicalResult parseTensorUse(TensorUse &result, - ComprehensionParsingState &state); - - /// Parses an attribute definition. - LogicalResult parseAttrDef(); - - /// Parses an optional attribute use. - LogicalResult parseAttrUse(AttrUse &result); - - /// Parses a tensor expression. - LogicalResult parseExpression(TensorUse currentDefinition, - std::unique_ptr &result, - ComprehensionParsingState &state); - - /// Parse a single comprehension. - LogicalResult parseOneComprehension(StringRef cppOpName, - StringRef linalgOpName, - ComprehensionParsingState &state); - - /// Parse and print the information for a TC def. - /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC. - /// When `gen-impl` is used, this prints the C++ implementation for the extra - /// methods defined in ODS (`iterator_types`, `indexing_maps` and - /// `regionBuilder`). - LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os); - - /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. - void printODS(llvm::raw_ostream &os, StringRef cppOpName, - StringRef linalgOpName, ArrayRef interfaces, - ComprehensionParsingState &state); - - /// Print the C++ StructuredOpsInterface impl of `iterator_types`. - void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, - ComprehensionParsingState &state); - - /// Print methods related to indexing map required attributes. - /// - /// Specifically, this prints the definitions for the following methods: - /// bool hasDynamicIndexingMaps(); - /// LogicalResult verifyIndexingMapRequiredAttributes(); - void printIndexingMapRequiredAttrMethods(llvm::raw_ostream &os, - StringRef cppOpName, - ComprehensionParsingState &state); - - /// Print the C++ StructuredOpsInterface impl of `indexing_maps`. - void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, - ComprehensionParsingState &state); - - /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. - void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, - ComprehensionParsingState &state); - - /// Print the C++ impl for named ops canonicalizers and folders. - void printCanonicalizersAndFolders(llvm::raw_ostream &os, - StringRef cppOpName); - -private: - //===--------------------------------------------------------------------===// - // Internal bookkeeping of tensors. - //===--------------------------------------------------------------------===// - struct RegisteredTensor { - StringRef type; - AffineMap shape; - bool isOutput; - AffineMap indexingMap; - unsigned index; - }; - - //===--------------------------------------------------------------------===// - // Internal bookkeeping of attributes. - //===--------------------------------------------------------------------===// - struct RegisteredAttr { - StringRef elementType; - SmallVector vectorDims; - bool isArray; - bool isOptional; - - // Returns the function to get values at the given indices from this - // attribute. - llvm::Optional getValueFn(ArrayRef indices) const; - }; - - //===--------------------------------------------------------------------===// - // Per-TC def state. - //===--------------------------------------------------------------------===// - /// Symbols are per TC def. - AffineSymbolList symbols; - - /// Attribute usages in all affine expressions. - SmallVector attrUses; - - /// Tensors are per TC def. - llvm::StringMap registeredTensors; - unsigned nextRegisteredTensorIndex; - - /// Attributes are per TC def. - std::map registeredAttrs; - - /// A map from AttrUse to AffineExpr symbol. - llvm::StringMap registeredAttrUseToSymbol; - - StringRef docString; - - Parser &parser; -}; -} // namespace - -namespace llvm { - -template <> struct DenseMapInfo { - static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); } - static TensorUse getTombstoneKey() { - return TensorUse(DenseMapInfo::getTombstoneKey(), - DenseMapInfo::getTombstoneKey()); - } - static unsigned getHashValue(const TensorUse &val) { - return ::llvm::hash_value(val.tensorId); // don't care about collisions. - } - static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) { - return LHS == RHS; - } -}; - -} // namespace llvm - -//===----------------------------------------------------------------------===// -// Visitation functions. -//===----------------------------------------------------------------------===// - -template -void visit(const Expression &expr, Lambda callback) { - switch (expr.kind) { - default: - llvm_unreachable("Unexpected kind"); - case Expression::Kind::TensorExpr: - static_cast(expr).visit(callback); - break; - case Expression::Kind::TensorUse: - static_cast(expr).visit(callback); - break; - } -} - -template -void visitPreorder(const Expression &expr, Lambda callback) { - visit(expr, callback); -} - -template -void visitPostorder(Expression &expr, Lambda callback) { - visit(expr, callback); -} - -template -void TensorExpr::visit(Lambda callback) const { - if (!PreOrder) - callback(*this); - for (auto &e : expressions) - ::visit(*e, callback); - if (PreOrder) - callback(*this); -} - -template -void TensorUse::visit(Lambda callback) const { - callback(*this); -} - -//===----------------------------------------------------------------------===// -// TC parsing functions. -//===----------------------------------------------------------------------===// -TCParser::TCParser(Parser &p) - : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {} - -/// Uses the AffineParser to parse the affine exprs used in a tensor -/// definition. All identifiers are interpreted as symbols, new symbols are -/// added eagerly. -SmallVector -TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode, - AffineDimList &dims, Token::Kind lDelim, - Token::Kind rDelim) { - auto createAffineBareId = [&](StringRef sRef) { - AffineExpr expr; - if (discoveryMode == EagerDiscoveryMode::Symbols) { - expr = getAffineSymbolExpr(symbols.size(), parser.context); - symbols.emplace_back(sRef, expr); - } else if (discoveryMode == EagerDiscoveryMode::Dimensions) { - expr = getAffineDimExpr(dims.size(), parser.context); - dims.emplace_back(sRef, expr); - } - return expr; - }; - - auto tryToParseAttrUse = [&]() -> llvm::Optional { - if (!parser.curToken.is(Token::Kind::id)) - return llvm::None; - - StringRef attrName = parser.curToken.getSpelling(); - auto it = registeredAttrs.find(attrName.str()); - if (it == registeredAttrs.end()) - return llvm::None; - - AttrUse result; - if (failed(parseAttrUse(result))) - return llvm::None; - - auto symbolIt = registeredAttrUseToSymbol.find(result.getKey()); - if (symbolIt == registeredAttrUseToSymbol.end()) { - result.symbol = getAffineSymbolExpr(symbols.size(), parser.context); - symbols.emplace_back("", result.symbol); - registeredAttrUseToSymbol[result.getKey()] = result.symbol; - attrUses.push_back(result); - } else { - result.symbol = symbolIt->second; - } - - return result.symbol; - }; - - AffineParser affineParser(parser, createAffineBareId, tryToParseAttrUse, dims, - symbols); - return affineParser.parseAffineExprs(lDelim, rDelim); -} - -/// Parse the information for a tensor def of the form: -/// -/// affine-expr-list ::= affine-expr (`,` affine-expr )* -/// tensor-typedef ::= type `(` `)` -/// | type `(` affine-expr-list `)` -/// tensor-def ::= bare-id `:` tensor-typedef -LogicalResult TCParser::parseTensorDef(bool isOutput) { - StringRef tensorId = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id")) || - failed(parser.parseToken(Token::Kind::colon, "expected colon"))) - return failure(); - - StringRef tensorType = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - - AffineDimList emptyDims; - auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims); - assert(emptyDims.empty() && "Unexpected dimension in tensor def"); - AffineMap map = - AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context); - - auto iterBoolPair = registeredTensors.try_emplace( - tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(), - nextRegisteredTensorIndex++}); - (void)iterBoolPair; - assert(iterBoolPair.second && "Could not emplace tensor registration"); - LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " " - << "with typeString: " << tensorType << " " - << "and shape: " << map << "\n"); - - return success(); -} - -/// Parses a tensor use of the form: -/// -/// affine-expr-list ::= affine-expr (`,` affine-expr)* -/// tensor-use ::= bare-id `(` `)` -/// | bare-id `(` affine-expr-list `)` -LogicalResult TCParser::parseTensorUse(TensorUse &result, - ComprehensionParsingState &state) { - StringRef tensorId = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - - auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims); - AffineMap map = - AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context); - LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map - << "\n"); - - result = TensorUse(tensorId, map); - return success(); -} - -/// Parse the information for an attribute def of the form: -/// -/// affine-expr-list ::= affine-expr (`,` affine-expr )* -/// attr-id ::= bare-id (`?`)? -/// dim-list ::= (integer-literal 'x')+ -/// attr-typedef ::= dim-list? type (`[` `]`)? -/// attr-def ::= attr-id `:` attr-typedef -LogicalResult TCParser::parseAttrDef() { - auto attrLoc = parser.curToken.getLoc(); - StringRef attrName = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question)); - if (failed(parser.parseToken(Token::Kind::colon, "expected colon"))) - return failure(); - - // Parse the attribute's type. We don't expect the type to be arbitrary - // complex, so just use this ad-hoc handling here. - - // Parse potential dimension list - SmallVector vectorDims; - while (parser.curToken.is(Token::Kind::integer)) { - uint64_t value; - if (failed(parser.parseInteger(value))) - return failure(); - vectorDims.push_back(value); - - StringRef spelling = parser.curToken.getSpelling(); - if (spelling[0] != 'x') - return parser.emitError(parser.curToken.getLoc(), - "expected 'x' in dimension list"); - - // If we had a prefix of 'x', lex the next token immediately after the 'x'. - if (spelling.size() != 1) - parser.lexer.resetPointer(spelling.data() + 1); - - parser.consumeToken(); - } - - StringRef elementType = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - - bool isArray = false; - auto arrayLoc = parser.curToken.getLoc(); - if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) { - isArray = true; - if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'"))) - return failure(); - } - - if (!vectorDims.empty() && isArray) - return parser.emitError(arrayLoc, "unsupported vector array attribute"); - - auto iterBoolPair = registeredAttrs.emplace( - attrName.str(), - RegisteredAttr{elementType, vectorDims, isArray, isOptional}); - if (!iterBoolPair.second) - return parser.emitError(attrLoc, - "Failed to register attribute '" + attrName + "'"); - - LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "") - << " " << attrName << " " - << "with type: " << elementType - << (isArray ? "[]" : "") << "\n"); - - return success(); -} - -LogicalResult TCParser::parseAttrUse(AttrUse &result) { - result.attrName = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - - auto it = registeredAttrs.find(result.attrName.str()); - assert(it != registeredAttrs.end()); - const RegisteredAttr &attr = it->second; - - if (!attr.vectorDims.empty() || attr.isArray) { - // This is a vector/array attribute. Parse indices for it. - auto indexLoc = parser.curToken.getLoc(); - - if (failed(parser.parseToken(Token::Kind::l_square, "expected '['"))) - return failure(); - - auto parseIndex = [&]() { - uint64_t value; - if (failed(parser.parseInteger(value))) - return failure(); - result.indices.push_back(value); - return success(); - }; - if (failed(parser.parseCommaSeparatedListUntil( - Token::Kind::r_square, parseIndex, /*allowEmptyList=*/false))) - return failure(); - - size_t rank = attr.isArray ? 1 : attr.vectorDims.size(); - if (result.indices.size() != rank) - return parser.emitError(indexLoc, - "number of indices mismatch: expected " + - std::to_string(rank) + ", but found " + - std::to_string(result.indices.size())); - } - - return success(); -} - -/// Parses a tensor expression of the form: -/// -/// op-spec ::= bare-id `<` reduction-dims-list `>` -/// | bare-id -/// op-arg ::= tensor-expr -/// | tensor-use -/// op-arg-list ::= op-arg (`,` op-arg)* -/// tensor-expr ::= op-spec `(` op-arg-list `)` -LogicalResult TCParser::parseExpression(TensorUse currentDefinition, - std::unique_ptr &result, - ComprehensionParsingState &state) { - StringRef opOrTensor = parser.curToken.getSpelling(); - if (registeredTensors.count(opOrTensor) > 0) { - TensorUse use; - auto res = parseTensorUse(use, state); - if (failed(res)) - return res; - result = std::make_unique(use); - return success(); - } - - if (failed(parser.parseToken(Token::Kind::id, "expected an operation"))) - return failure(); - - // This is an op. - SmallVector reductionDims; - SmallVector, 4> expressions; - - // Check if it has a reduction set, discover dimensions eagerly. - if (parser.curToken.is(Token::Kind::lt)) { - auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims, - Token::Kind::lt, Token::Kind::gt); - for (auto iter : iters) - reductionDims.push_back(iter.cast().getPosition()); - } - - auto parseExpr = [&]() -> LogicalResult { - std::unique_ptr e; - if (failed(parseExpression(currentDefinition, e, state))) - return failure(); - expressions.push_back(std::move(e)); - return success(); - }; - if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) || - failed(parser.parseCommaSeparatedListUntil( - Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true))) - return failure(); - - result = std::make_unique(opOrTensor, std::move(expressions), - reductionDims); - - return success(); -} - -//===----------------------------------------------------------------------===// -// Parse and Emit functions. -//===----------------------------------------------------------------------===// - -/// Parse the information for a single comprehension. -/// -/// tensor-def-list ::= tensor-def (`,` tensor-def)* -/// tensor-expr-list ::= tensor-expr (`,` tensor-expr)* -/// comprehension ::= tensor-def-list `=` tensor-expr-list `;` -LogicalResult -TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName, - ComprehensionParsingState &state) { - // 1. Parse LHS of `=`, these become the definitions that appear as the output - // tensors or read/write buffers. - SmallVector definitions; - auto parseUse = [&]() -> LogicalResult { - TensorUse use; - if (failed(parseTensorUse(use, state))) - return failure(); - definitions.push_back(use); - return success(); - }; - if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse, - /*allowEmptyList=*/true))) - return failure(); - - // 2. Parse RHS of `=`, this becomes the expressions from which we emit - // computations. - unsigned idx = 0; - auto parseExpr = [&]() -> LogicalResult { - std::unique_ptr expr; - if (idx >= definitions.size()) - return parser.emitError("Fewer LHS definitions than RHS expressions"); - if (failed(parseExpression(definitions[idx++], expr, state))) - return failure(); - state.expressions.push_back(std::move(expr)); - return success(); - }; - if (failed(parser.parseCommaSeparatedListUntil( - Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true))) - return failure(); - if (idx != definitions.size()) - return parser.emitError("Fewer RHS expressions than LHS definitions"); - - // 3. Postprocess. - // 3.a. Normalize all maps to the proper state.dims and symbols counts. - SmallVector allUses; - allUses.reserve(registeredTensors.size()); - for (auto &def : definitions) - allUses.push_back(def); - for (auto &pExpr : state.expressions) - visitPostorder(*pExpr, [&](const Expression &e) { - if (auto *use = dyn_cast(&e)) - allUses.push_back(*use); - }); - for (auto &use : allUses) - use.indexingMap = - AffineMap::get(state.dims.size(), symbols.size(), - use.indexingMap.getResults(), parser.context); - - // 3.b. Traverse definitions - llvm::DenseSet seenDefs; - for (auto &def : definitions) { - if (seenDefs.count(def.tensorId) > 0) - return parser.emitError("Unexpected multi-write to a single tensor"); - seenDefs.insert(def.tensorId); - auto tensorIter = registeredTensors.find(def.tensorId); - assert(tensorIter != registeredTensors.end() && "unregistered tensor"); - auto &tensor = tensorIter->getValue(); - tensor.indexingMap = def.indexingMap; - state.orderedTensorArgs[def] = tensor.index; - } - - bool failed = false; - for (auto &pExpr : state.expressions) - visitPostorder(*pExpr, [&](const Expression &e) { - auto *pUse = dyn_cast(&e); - if (failed || !pUse) - return; - auto &use = *pUse; - LLVM_DEBUG(llvm::dbgs() - << "\nuse: " << use.tensorId << " map: " << use.indexingMap); - auto tensorIter = registeredTensors.find(use.tensorId); - assert(tensorIter != registeredTensors.end() && "unregistered tensor"); - auto &tensor = tensorIter->getValue(); - if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 && - tensor.indexingMap.getResults() != use.indexingMap.getResults()) { - LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap); - (void)parser.emitError( - "Unexpected multi-read of a tensor with different accesses"); - failed = true; - return; - } - seenDefs.insert(use.tensorId); - tensor.indexingMap = use.indexingMap; - state.orderedTensorArgs[use] = tensor.index; - }); - // If more than one definitions are less. They are shaped-only operand, which - // are used to define reduction loops. For now, only accept exactly one - // shaped-only operand. - if (state.numArgs > seenDefs.size() + 1) { - failed = true; - } else if (state.numArgs == seenDefs.size() + 1) { - for (auto &tensorIter : registeredTensors) { - auto &tensor = tensorIter.getValue(); - if (tensor.indexingMap) - continue; - if (auto *pTensorExpr = - dyn_cast(state.expressions[0].get())) { - SmallVector exprs; - for (auto dim : pTensorExpr->reductionDimensions) - exprs.push_back(getAffineDimExpr(dim, parser.context)); - tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(), - exprs, parser.context); - } - } - } - if (failed) - return failure(); - - return success(); -} - -/// Parse and print the information for a ODS def. -/// -/// tensor-def-list ::= tensor-def (`,` tensor-def )* -/// attr-def-list ::= attr-def (`,` attr-def )* -/// -/// comprehension-list ::= comprehension comprehension* -/// -/// tc-attr-def ::= `attr` `(` attr-def-list `)` -/// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` -/// (tc-attr-def)? -/// `{` comprehension-list `}` -/// -/// implements-interface ::= -/// `implements_interface` `<` bare-id (`,` bare-id)* `>` `:` tc-def -/// -/// ods-def ::= `ods_def` `<` bare-id `>` -/// (implements-interface)? `:` -/// tc-def -/// -/// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e. -/// contain only expressions involving symbols and constants), but can -/// otherwise contain arbitrary affine expressions. -LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { - // Parse ods-def header (including C++ op name) - if (failed(parser.parseToken(Token::Kind::kw_ods_def, - "expected 'ods_def' to define a TC ODS")) || - failed(parser.parseToken(Token::Kind::lt, "expected '<'"))) - return failure(); - StringRef cppOpName = parser.curToken.getSpelling(); - LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n"); - if (failed(parser.parseToken(Token::Kind::id, "expected id")) || - failed(parser.parseToken(Token::Kind::gt, "expected '>'"))) - return failure(); - - // Parse optional implements-interface header (including C++ op names) - SmallVector interfaces; - bool implementsInterface = succeeded( - parser.parseOptionalToken(Token::Kind::kw_implements_interface)); - if (implementsInterface) { - auto parseInterfaceString = [&]() -> LogicalResult { - StringRef interfaceName = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected id"))) - return failure(); - interfaces.push_back(interfaceName); - return success(); - }; - if (failed(parser.parseToken(Token::Kind::lt, "expected '<'")) || - failed(parser.parseCommaSeparatedListUntil( - Token::Kind::gt, parseInterfaceString, /*allowEmptyList=*/false))) - return failure(); - } - - // Parse column. - if (failed(parser.parseToken(Token::Kind::colon, "expected ':'"))) - return failure(); - - // Parse TC op name. - if (failed(parser.parseToken(Token::Kind::kw_def, - "expected 'def' to define a TC"))) - return failure(); - StringRef tcName = parser.curToken.getSpelling(); - LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n"); - - // Parse input/output tensor definitions - if (failed(parser.parseToken(Token::Kind::id, "expected id")) || - failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) - return failure(); - - auto parseInputDef = [&]() -> LogicalResult { - return parseTensorDef(/*isOutput=*/false); - }; - if (failed(parser.parseCommaSeparatedListUntil( - Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false))) - return failure(); - - if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) || - failed(parser.parseToken(Token::Kind::gt, "expected '>'")) || - failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) - return failure(); - auto parseOutputDef = [&]() -> LogicalResult { - return parseTensorDef(/*isOutput=*/true); - }; - if (failed(parser.parseCommaSeparatedListUntil( - Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false))) - return failure(); - - // Parse optional attribute definitions - if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) { - if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) - return failure(); - if (failed(parser.parseCommaSeparatedListUntil( - Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this), - /*allowEmptyList=*/false))) - return failure(); - } - - // Parse optional doc string - if (parser.curToken.is(Token::Kind::doc_str)) { - docString = parser.curToken.getSpelling(); - parser.consumeToken(); - LLVM_DEBUG(llvm::dbgs() - << "parsed doc string: '''" << docString << "'''\n"); - } - - // Since we don't declare symbols separately, we discover them eagerly: each - // newly encountered id in a tensor shape expression is treated as a new - // symbolic. At this point, all tensors have been parsed and all the symbols - // that could be discovered eagerly are now known. Resize all AffineMaps to - // normalize the number of eagerly discovered symbols. - for (auto &tensor : registeredTensors) { - auto &map = tensor.getValue().shape; - map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(), - parser.context); - } - - if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'"))) - return failure(); - - SmallVector perComprehensionStates; - while (parser.curToken.isNot(Token::Kind::r_brace)) { - perComprehensionStates.push_back(ComprehensionParsingState()); - perComprehensionStates.back().numArgs = registeredTensors.size(); - if (failed(parseOneComprehension(cppOpName, tcName, - perComprehensionStates.back()))) - return failure(); - }; - if (failed(parser.parseToken(Token::Kind::r_brace, "expected '}'"))) - return failure(); - - // Print. - auto nComprehensions = perComprehensionStates.size(); - if (nComprehensions != 1) - return parser.emitError("only 1 comprehension supported for now, got: " + - llvm::Twine(nComprehensions)); - if (genODSDecl) { - auto &state = perComprehensionStates.back(); - printODS(os, cppOpName, tcName, interfaces, state); - os << "\n"; - } - if (genODSImpl) { - auto &state = perComprehensionStates.back(); - std::string extraMethods; - llvm::raw_string_ostream ss(extraMethods); - printReferenceIterators(ss, cppOpName, state); - printIndexingMapRequiredAttrMethods(ss, cppOpName, state); - printReferenceIndexingMaps(ss, cppOpName, state); - printRegionBuilder(ss, cppOpName, state); - printCanonicalizersAndFolders(ss, cppOpName); - ss.flush(); - os << extraMethods << "\n"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing functions -//===----------------------------------------------------------------------===// - -/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. -void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, - StringRef linalgOpName, ArrayRef interfaces, - ComprehensionParsingState &state) { - SmallVector attributes; - for (const auto &attr : registeredAttrs) { - llvm::StringRef name = attr.first; - - llvm::StringRef elementType = attr.second.elementType; - std::string odsType = llvm::StringSwitch(elementType) - .Case("f32", "F32") - .Case("i32", "I32") - .Case("i64", "I64") - .Default(""); - if (odsType.empty()) { - (void)parser.emitError( - "unimplemented support for attribute element type: " + elementType); - return; - } - - const auto &dims = attr.second.vectorDims; - if (!dims.empty()) { - // Vector case - SmallVector dimStrs; - for (uint64_t dim : dims) - dimStrs.push_back(std::to_string(dim)); - odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType, - llvm::join(dimStrs, ", ")); - } else if (attr.second.isArray) { - // Array case - odsType = llvm::formatv("{0}ArrayAttr", odsType); - } else { - // Scalar case - odsType = llvm::formatv("{0}Attr", odsType); - } - - if (attr.second.isOptional) - odsType = llvm::formatv("OptionalAttr<{0}>", odsType); - - attributes.push_back(llvm::formatv("{0}:${1}", odsType, name)); - } - - std::string attrList = llvm::join(attributes, ",\n"); - if (!attrList.empty()) - attrList = ",\n" + attrList; - - // Template for Linalg named ops' ODS definitions. Parameters: - // {0}: ODS/C++ op name - // {1}: assembly op mnemonic - // {2}: op interface list - // {3}: documentation (summary + description) - // {4}: op attribute list - // {5}: the number of arguments for the op region - // {6}: builder methods taking standalone attribute parameters - // {7}: additional methods for attributes used by indexing maps - const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [ - AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"YieldOp"> - /*extraInterfaces=*/{2}]> { - {3} - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs{4} - ); - let results = (outs Variadic:$result_tensors); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"ArrayRef", "{{}">:$attributes), - [{{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({{ - static_cast(inputs.size()), - static_cast(outputs.size())})); - $_state.addAttributes(attributes); - createAndFillStructuredOpRegion<{0}>( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs)); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{{}">:$attributes), - [{{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addTypes(resultTensorTypes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({{ - static_cast(inputs.size()), - static_cast(outputs.size())})); - $_state.addAttributes(attributes); - createAndFillStructuredOpRegion<{0}>( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs)); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{{}">:$attributes), - [{{ - $_state.addOperands(operands); - $_state.addAttributes(attributes); - $_state.addTypes(resultTensorTypes); - (void)$_state.addRegion(); - }]> - {6} - ]; - let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; - let parser = [{{ - return ::parseNamedStructuredOp<{0}>(parser, result); - }]; - let hasFolder = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{{ - // Auto-generated. - ArrayAttr iterator_types(); - ArrayAttr indexing_maps(); - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); - static std::function - getRegionBuilder() {{ - return regionBuilder; - } - - // Generic methods. - static unsigned getNumRegionArgs() {{ return {5}; } - std::string getLibraryCallName() {{ - return generateLibraryCallName(getOperation()); - } - - {7} - }]; - })FMT"; - - // Generate the list of extra implemented interfaces. - std::string interfaceNameList; - if (!interfaces.empty()) { - llvm::raw_string_ostream ss(interfaceNameList); - ss << ", "; // Leading comma to concat to existing list of interfaces. - llvm::interleaveComma(interfaces, ss); - ss.flush(); - } - - // Generate documentation. - std::string doc; - if (!docString.empty()) { - const char *docFmt = R"FMT( - let summary = [{ {0} }]; - let description = [{ - {1} - }]; - )FMT"; - - StringRef summary, description; - std::tie(summary, description) = docString.trim().split('\n'); - doc = llvm::formatv(docFmt, summary.trim(), description.trim()); - } - - // Generate an additional builder that has parameters for attributes. - std::string attrBuilder; - if (!registeredAttrs.empty()) { - SmallVector attrParams, attrStmts; - for (const auto &attr : registeredAttrs) { - llvm::StringRef name = attr.first; - attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name)); - attrStmts.push_back( - llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name)); - } - std::string attrParamsList = llvm::join(attrParams, ", "); - std::string attrStmtsList = llvm::join(attrStmts, "\n"); - - const char *builderFmt = R"FMT( - , OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, {1}, - CArg<"ArrayRef", "{{}">:$attributes), - [{{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addTypes(resultTensorTypes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({{ - static_cast(inputs.size()), - static_cast(outputs.size())})); - $_state.addAttributes(attributes); - createAndFillStructuredOpRegion<{0}>( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs)); - {2} - }]> - )FMT"; - attrBuilder = - llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList); - } - - std::string attrMethods; - if (!registeredAttrs.empty()) { - attrMethods = R"( - bool hasDynamicIndexingMaps(); - LogicalResult verifyIndexingMapRequiredAttributes(); - )"; - } - - // Finally put everything together. - os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc, - attrList, state.numArgs, attrBuilder, attrMethods); -} - -/// Print the C++ StructuredOpsInterface impl of `iterator_types`. -void TCParser::printReferenceIterators(llvm::raw_ostream &os, - StringRef cppOpName, - ComprehensionParsingState &state) { - const char *referenceReferenceIteratorsFmt = - R"FMT( - ArrayAttr {0}::iterator_types() { - return Builder(getContext()).getStrArrayAttr(SmallVector{{ {1} }); - })FMT"; - - std::string iteratorsStr; - llvm::raw_string_ostream ss(iteratorsStr); - unsigned pos = 0; - llvm::interleaveComma( - state.dims, ss, [&](std::pair p) { - bool reduction = false; - for (auto &expr : state.expressions) { - visitPostorder(*expr, [&](const Expression &e) { - if (auto *pTensorExpr = dyn_cast(&e)) { - if (pTensorExpr->reductionDimensions.count(pos) > 0) - reduction = true; - } - }); - if (reduction) - break; - } - ss << (reduction ? "getReductionIteratorTypeName()" - : "getParallelIteratorTypeName()"); - pos++; - }); - ss.flush(); - - os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr); -} - -void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os, - StringRef cppOpName) { - const char *foldersFmt = R"FMT( - LogicalResult {0}::fold(ArrayRef, - SmallVectorImpl &) {{ - return foldMemRefCast(*this); - } - void {0}::getEffects(SmallVectorImpl< - SideEffects::EffectInstance >&effects) {{ - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, - getOperation()->getResults(), inputBuffers, outputBuffers); - })FMT"; - os << llvm::formatv(foldersFmt, cppOpName); -} - -// Prints methods for querying whether the current named op has attributes that -// are used by its indexing maps and for verifying those attributes have the -// expected type. -void TCParser::printIndexingMapRequiredAttrMethods( - llvm::raw_ostream &os, StringRef cppOpName, - ComprehensionParsingState &state) { - // If there are no attribute used by the whole definition, then we are done. - if (registeredAttrs.empty()) - return; - - // Otherwise, go through each attribute and generate code to verify it's - // valid per the spec. - SmallVector attributes; - for (const auto &attr : registeredAttrs) { - if (attr.second.isOptional) - continue; - - llvm::StringRef name = attr.first; - llvm::StringRef elementType = attr.second.elementType; - const auto &dims = attr.second.vectorDims; - - // Get the method call to check the element type is of the expected kind. - std::string elemTypeCheck = llvm::StringSwitch(elementType) - .Case("f32", "isF32()") - .Case("i32", "isInteger(32)") - .Case("i64", "isInteger(64)") - .Default(""); - if (elemTypeCheck.empty()) { - (void)parser.emitError( - "unimplemented support for attribute element type: " + elementType); - return; - } - - // Scalar case. - if (dims.empty() && !attr.second.isArray) { - const char *attrFmt = R"FMT( - if (auto attr = op->getAttr("{0}")) {{ - if (!attr.getType().{1}) return op->emitError( - "incorrect type for indexing map required attribute '{0}'"); - } else {{ - return op->emitError( - "missing indexing map required attribute '{0}'"); - } - )FMT"; - - attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck)); - continue; - } - - // Vector case. - if (!dims.empty()) { - SmallVector dimStrs; - for (uint64_t dim : dims) - dimStrs.push_back(std::to_string(dim)); - - const char *attrFmt = R"FMT( - if (auto attr = op->getAttrOfType("{0}")) {{ - if (!attr.getType().getElementType().{1}) return op->emitError( - "incorrect element type for indexing map required attribute '{0}'"); - if (attr.getType().getShape() != ArrayRef{{ {2} }) - return op->emitError( - "incorrect shape for indexing map required attribute '{0}'"); - } else { - return op->emitError( - "missing indexing map required attribute '{0}'"); - } - )FMT"; - - attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck, - llvm::join(dimStrs, ", "))); - continue; - } - - // Array case. - { - const char *attrFmt = R"FMT( - if (auto attr = op->getAttrOfType("{0}")) {{ - for (Attribute element : attr) {{ - if (!element.getType().{1}) return emitError( - "incorrect element type for indexing map required attribute '{0}'"); - } - } else {{ - return op->emitError( - "missing indexing map required attribute '{0}'"); - } - )FMT"; - - attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck)); - } - } - - const char *methodFmt = R"FMT( - bool {0}::hasDynamicIndexingMaps() {{ return true; } - - LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ - Operation *op = getOperation(); - {1} - return success(); - } - )FMT"; - - // Print everything out. - os << llvm::formatv(methodFmt, cppOpName, llvm::join(attributes, "\n")); -} - -/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. -void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, - StringRef cppOpName, - ComprehensionParsingState &state) { - // 1. Generic string template for specifying reference indexing maps. - const char *referenceIndexingMapsFmt = - R"FMT( - // This is temporary until we transition out of manually specified ops that - // should be auto-generated with linalg-ods-gen. - ArrayAttr {0}::indexing_maps() { - MLIRContext *context = getContext(); - AffineExpr {1}; - bindDims(context, {1}); - {2} - return Builder(context).getAffineMapArrayAttr({ {3} }); - })FMT"; - - // 2. Print a comma-separated list of identifiers for the AffineExpr in - // `state.dims`. These will replace the `{1}` placeholder in both - // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr - // identifiers are bound in the right order to the proper AffineDimExpr. - std::string dimsStr; - llvm::raw_string_ostream ss(dimsStr); - llvm::interleaveComma( - state.dims, ss, - [&](std::pair p) { ss << p.second; }); - ss.flush(); - - // 3. Get the list of affine maps for each input/output. The AffineExpr use - // the common arithmetic operators on AffineExpr. These affine maps will - // replace the `{2}` placeholder. - std::string mapsStr; - llvm::raw_string_ostream mapsStringStream(mapsStr); - - // Create a list of all symbols. - SmallVector symbolReplacements; - symbolReplacements.reserve(symbols.size()); - for (unsigned i = 0; i < symbols.size(); ++i) { - const char *symFmt = - "\n\tauto s{0} = getAffineSymbolExpr({0}, context); (void)s{0};"; - mapsStringStream << llvm::formatv(symFmt, i); - symbolReplacements.push_back(llvm::formatv("s{0}", i)); - } - - // Create the affine constant expressions to replace symbols for attributes. - for (auto attrUse : llvm::enumerate(attrUses)) { - StringRef attrName = attrUse.value().attrName; - auto it = registeredAttrs.find(attrName.str()); - assert(it != registeredAttrs.end() && "uses should point to valid attr!"); - llvm::Optional getValueFn = - it->second.getValueFn(attrUse.value().indices); - if (!getValueFn) { - (void)parser.emitError("unimplemented getValueFn for attribute: " + - attrName); - return; - } - std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn); - const char *cstFmt = - "\n\tauto cst{0} = getAffineConstantExpr({1}, context);"; - mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal); - - unsigned position = - attrUse.value().symbol.cast().getPosition(); - symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index()); - } - - // For each registered tensor, construct the affine map, replace symbols by - // the corresponding attribute values, and simplify the affine map. - for (auto &tensorIter : registeredTensors) { - auto &tensor = tensorIter.getValue(); - auto indexingMap = tensor.indexingMap; - const char *mapFmt = - "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);"; - - std::string exprsStr; - llvm::raw_string_ostream exprsStringStream(exprsStr); - exprsStringStream << "{"; - llvm::interleaveComma(indexingMap.getResults(), exprsStringStream); - exprsStringStream << "}"; - exprsStringStream.flush(); - mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(), - indexingMap.getNumSymbols(), exprsStr); - - std::string replaceSymbolList = - llvm::formatv("{ {0} }", llvm::join(symbolReplacements, ", ")); - - // Note that we use `0` as the result affine map's number of symbols. All - // symbols representing attribute usages should be folded away. But there - // may exist additional symbols for tensor dimension upper bounds. Linalg - // does not handle such cases right now. This needs to be fixed once we - // need that. - const char *replaceFmt = - "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);"; - mapsStringStream << llvm::formatv(replaceFmt, tensor.index, - replaceSymbolList, state.dims.size()); - const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});"; - mapsStringStream << llvm::formatv(simplifyFmt, tensor.index); - } - - mapsStringStream.flush(); - - SmallVector mapList; - mapList.reserve(state.numArgs); - for (auto i : llvm::seq(0, state.numArgs)) - mapList.push_back(llvm::formatv("map{0}", i)); - - // 4. Apply format to 1. using 2. and 3. - os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr, - llvm::join(mapList, ", ")); -} - -/// Print the C++ StructuredOpsInterface impl of `regionBuilder`. -void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, - ComprehensionParsingState &state) { - unsigned count = state.numArgs; - llvm::DenseMap subExprsMap; - std::function printExpr; - printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void { - if (auto *pUse = dyn_cast(&e)) { - os << "_" << state.orderedTensorArgs.find(*pUse)->second; - return; - } - auto *pTensorExpr = cast(&e); - if (subExprsMap.count(pTensorExpr) > 0) { - os << "_" << subExprsMap[pTensorExpr]; - } else { - std::string subExprs; - llvm::raw_string_ostream subExprsStringStream(subExprs); - llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream, - [&](const std::unique_ptr &e) { - printExpr(subExprsStringStream, *e); - }); - subExprsStringStream.flush(); - const char *tensorExprFmt = "\n Value _{0} = b.create<{1}>({2});"; - os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName, - subExprs); - subExprsMap[pTensorExpr] = count; - } - }; - - const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { - auto args = block.getArguments(); - Value {1}; - {2} - b.create(ValueRange{ {3} }); - })FMT"; - - std::string valueHandleStr; - llvm::raw_string_ostream valueHandleStringStream(valueHandleStr); - std::set usedTensorId; - for (const auto &iter : state.orderedTensorArgs) - usedTensorId.insert(iter.second); - llvm::interleaveComma(usedTensorId, valueHandleStringStream, [&](auto idx) { - valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; - }); - - std::string expressionsStr; - llvm::raw_string_ostream expressionStringStream(expressionsStr); - for (auto &expr : state.expressions) - visitPostorder(*expr, [&](const Expression &e) { - if (e.kind == Expression::Kind::TensorExpr) - printExpr(expressionStringStream, e); - }); - expressionStringStream.flush(); - substituteOpAliases(expressionsStr); - - std::string yieldStr; - llvm::raw_string_ostream yieldStringStream(yieldStr); - llvm::interleaveComma(state.expressions, yieldStringStream, - [&](const std::unique_ptr &e) { - printExpr(yieldStringStream, *e); - }); - - valueHandleStringStream.flush(); - yieldStringStream.flush(); - - os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr, - expressionsStr, yieldStr); -} - -llvm::Optional -TCParser::RegisteredAttr::getValueFn(ArrayRef indices) const { - if (isArray) - return llvm::None; - - if (!vectorDims.empty()) { - SmallVector indexStrs; - for (uint64_t index : indices) - indexStrs.push_back(std::to_string(index)); - std::string indexList = llvm::join(indexStrs, ", "); - if (elementType == "f32") - return llvm::formatv(".getValue({ {0} })", indexList).str(); - if (elementType == "i32") - return llvm::formatv(".getValue({ {0} })", indexList).str(); - if (elementType == "i64") - return llvm::formatv(".getValue({ {0} })", indexList).str(); - - return llvm::None; - } - - if (elementType == "f32") - return std::string(".convertToFloat()"); - if (elementType == "i32" || elementType == "i64") - return std::string(""); - return llvm::None; -} - -/// Iterate over each Tensor Comprehension def. -LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os, - Parser &parser) { - while (parser.curToken.getKind() != Token::Kind::eof) { - TCParser tcParser(parser); - if (failed(tcParser.parseAndEmitODSDef(os))) - return failure(); - } - return success(); -} - -int main(int argc, char **argv) { - llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen"); - - // Set up the input file. - std::string errorMessage; - std::unique_ptr file = - mlir::openInputFile(inputFilename, &errorMessage); - if (!file) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - std::unique_ptr output = - openOutputFile(outputFilename, &errorMessage); - if (!output) { - llvm::errs() << errorMessage << "\n"; - exit(1); - } - - // Include the proper Linalg header for end-to-end tblgen testing without - // resorting to non-portable shell manipulations. - if (testEmitIncludeTdHeader) - output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\""; - - MLIRContext context; - llvm::SourceMgr mgr; - mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); - Parser parser(mgr, &context); - (void)parseAndEmitAllTensorComprehensions(output->os(), parser); - output->keep(); - - return 0; -} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5465,20 +5465,6 @@ ], ) -cc_binary( - name = "mlir-linalg-ods-gen", - srcs = [ - "tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp", - ], - deps = [ - ":IR", - ":Support", - "//llvm:Support", - "//llvm:TableGen", - "//llvm:config", - ], -) - cc_binary( name = "mlir-linalg-ods-yaml-gen", srcs = [ @@ -5911,22 +5897,6 @@ deps = [":LinalgOpsTdFiles"], ) -genlinalg( - name = "LinalgNamedStructuredOpsTcIncGen", - src = "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc", - linalg_outs = [ - ( - "-gen-impl -o=$@", - "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc", - ), - ( - "-gen-ods-decl -o=$@", - "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td", - ), - ], - linalggen = ":mlir-linalg-ods-gen", -) - genlinalg( name = "LinalgNamedStructuredOpsYamlIncGen", src = "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml", @@ -5947,7 +5917,6 @@ name = "LinalgStructuredOpsTdFiles", srcs = [ "include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td", - "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td", "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td", "include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td", ], @@ -6123,7 +6092,6 @@ ":InferTypeOpInterface", ":LinalgInterfaces", ":LinalgInterfacesIncGen", - ":LinalgNamedStructuredOpsTcIncGen", ":LinalgNamedStructuredOpsYamlIncGen", ":LinalgOpsIncGen", ":LinalgStructuredOpsIncGen",