diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/AggregatedOpInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -14,6 +14,7 @@
 #define LINALG_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Interfaces/AggregatedOpInterface.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -93,6 +94,7 @@
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1155,6 +1155,33 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decomposes each target op via its `AggregatedOpInterface` implementation
+    and replaces it with the returned values. Fails silenceably if the op does
+    not implement the interface or if the decomposition fails.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/AggregatedOpInterface.h b/mlir/include/mlir/Interfaces/AggregatedOpInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/AggregatedOpInterface.h
@@ -0,0 +1,28 @@
+//===- AggregatedOpInterface.h - Interface for aggregated ops ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the AggregatedOpInterface defined in
+// `AggregatedOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_AGGREGATEDOPINTERFACE_H_
+#define MLIR_INTERFACES_AGGREGATEDOPINTERFACE_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LLVM.h"
+
+
+/// Include the ODS generated interface header files.
+#include "mlir/Interfaces/AggregatedOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_AGGREGATEDOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/AggregatedOpInterface.td b/mlir/include/mlir/Interfaces/AggregatedOpInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/AggregatedOpInterface.td
@@ -0,0 +1,45 @@
+//===- AggregatedOpInterface.td - Aggregated op interface --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains an interface that allows operations to decompose
+// themselves into a sequence of simpler operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_AGGREGATEDOPINTERFACE
+#define MLIR_AGGREGATEDOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
+  let description = [{
+    Interface for decomposing aggregated operations into a sequence of simpler
+    ops.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to decompose the operation into simpler operations.
+          On success, this method returns the `Value`s that can be used
+          to replace the results of the current operation.
+          Therefore, it must return one value per result.
+          The default implementation reports failure.
+        }],
+        /*retType=*/"FailureOr<SmallVector<Value>>",
+        /*methodName=*/"decomposeOperation",
+        /*args=*/(ins
+            "OpBuilder &":$b),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >
+  ];
+}
+#endif // MLIR_AGGREGATEDOPINTERFACE
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_mlir_interface(AggregatedOpInterface)
 add_mlir_interface(CallInterfaces)
 add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -16,6 +16,7 @@
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRAggregatedOpInterface
   MLIRArithDialect
   MLIRArithUtils
   MLIRBufferizationDialect
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2285,6 +2285,177 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+// Helper functions for softmax decomposition.
+// @{
+
+// Helper function to produce the iterator types (reduction or parallel) and
+// affine maps for the iterators used in the decomposition of softmax.
+// This method creates:
+// If allParallel == true:
+// - iterator types: {parallel, ..., parallel}
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank - 1.
+//
+// If allParallel == false:
+// - iterator types: parallel for every dimension i != \p dim and
+//   reduction for dimension \p dim.
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank - 1.
+static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
+computeIteratorTypesAndIndexingMaps(int64_t inputRank, int64_t dim,
+                                    OpBuilder &builder,
+                                    bool allParallel = false) {
+  assert(dim >= 0 && dim < inputRank && "softmax dimension out of bounds");
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  if (!allParallel)
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+  MLIRContext *ctx = builder.getContext();
+  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctx);
+  // Keep every dimension except `dim` in the results of the second map.
+  SmallVector<AffineExpr, 2> affineExprs;
+  for (int64_t i = 0; i < inputRank; i++)
+    if (i != dim)
+      affineExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+  auto reductionMap = AffineMap::get(inputRank, 0, affineExprs, ctx);
+  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
+  return std::make_tuple(iteratorTypes, indexingMaps);
+}
+
+// Helper function to produce a linalg.generic that reduces \p input along
+// dimension \p dim by combining elements with the binary operation \p T;
+// \p output is passed as the init operand and provides the result type.
+template <typename T>
+static Value reduce(Value input, Value output, int64_t dim, Location loc,
+                    OpBuilder &builder) {
+  auto inputType = cast<ShapedType>(input.getType());
+  int64_t inputRank = inputType.getShape().size();
+  auto [iteratorTypes, indexingMaps] =
+      computeIteratorTypesAndIndexingMaps(inputRank, dim, builder);
+  assert(indexingMaps.size() == 2 &&
+         "We should have two maps: 1 for the input, 1 for the output");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // args[0] is the input element, args[1] the output (init) element.
+        Value result = b.create<T>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Build the linalg.generic implementing the second step of the softmax
+/// decomposition, res = exp(input - max), where \p max holds the maximum of
+/// \p input along dimension \p dim.
+static Value subtractAndExp(Value input, Value max, Value output, int64_t dim,
+                            Location loc, OpBuilder &builder) {
+  int64_t rank = cast<ShapedType>(input.getType()).getShape().size();
+  auto [iterators, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      rank, dim, builder, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 && "We should one map for each input");
+  assert(indexingMaps[0].isIdentity() &&
+         "All parallel should give us identity map for input");
+  // The output is indexed like the input, so reuse the identity map for it.
+  indexingMaps.push_back(indexingMaps[0]);
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
+      iterators, [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+        // args[0] is the input element, args[1] the matching max element.
+        Value shifted = b.create<arith::SubFOp>(nestedLoc, args[0], args[1]);
+        Value exponential = b.create<math::ExpOp>(nestedLoc, shifted);
+        b.create<linalg::YieldOp>(nestedLoc, exponential);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the final step of the softmax
+/// decomposition; \p denominator is broadcast along dimension \p dim.
+/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
+///   yield  n / d
+/// }
+static Value computeSoftmax(Value numerator, Value denominator, Value output,
+                            int64_t dim, Location loc, OpBuilder &builder) {
+  auto inputType = cast<ShapedType>(numerator.getType());
+  int64_t inputRank = inputType.getShape().size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      inputRank, dim, builder, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 &&
+         "We should have one map for each input (2)");
+  assert(indexingMaps[0].isIdentity() &&
+         "All parallel should give us identity map");
+  // Add the affine map for the output tensor.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+// @} End helper functions for softmax decomposition.
+
+/// Decompose softmax(x), for an N-dimensional tensor x and a reduction
+/// dimension d, into the following sequence of operations:
+///
+/// 1. Reduce x with max along d, producing an (N-1)-dimensional
+///    tensor m.
+///    m = max(x, dim = d)
+///
+/// 2. Subtract m from x and exponentiate, producing an
+///    N-dimensional tensor z.
+///    z = exp(x - m)
+///
+/// 3. Reduce z with add along d, producing an (N-1)-dimensional
+///    tensor l.
+///    l = sum(z, dim = d)
+///
+/// 4. Divide z by l, which yields the N-dimensional softmax.
+///    softmax = z / l
+///
+FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  // Emit the replacement ops immediately before the softmax op.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(*this);
+  Location loc = getLoc();
+  Value src = getInput();
+  int64_t redDim = getDimension();
+  Type eltTy = getInputOperandType().getElementType();
+  // One destination with the input's sizes, one with `redDim` dropped.
+  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, src);
+  Value fullInit = b.create<tensor::EmptyOp>(loc, sizes, eltTy);
+  sizes.erase(sizes.begin() + redDim);
+  Value reducedInit = b.create<tensor::EmptyOp>(loc, sizes, eltTy);
+
+  // Step 1: Compute max along dim.
+  Value maxIdentity =
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, eltTy, b, loc);
+  Value maxAcc =
+      b.create<linalg::FillOp>(loc, Value{maxIdentity}, reducedInit).result();
+  Value max = reduce<arith::MaxFOp>(src, maxAcc, redDim, loc, b);
+
+  // Step 2: Subtract max from input and exponentiate.
+  Value numerator = subtractAndExp(src, max, fullInit, redDim, loc, b);
+
+  // Step 3: Compute sum along dim.
+  Value addIdentity =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, eltTy, b, loc);
+  Value sumAcc =
+      b.create<linalg::FillOp>(loc, Value{addIdentity}, reducedInit).result();
+  Value denominator = reduce<arith::AddFOp>(numerator, sumAcc, redDim, loc, b);
+
+  // Step 4: Compute softmax.
+  return SmallVector<Value>{
+      computeSoftmax(numerator, denominator, fullInit, redDim, loc, b)};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/AggregatedOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
@@ -236,6 +237,29 @@
   return emitDefaultSilenceableFailure(target);
 }
 
+/// Decomposes `target` via AggregatedOpInterface and tracks the new ops.
+DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+  if (!decomposableOp) {
+    (void)rewriter.notifyMatchFailure(target,
+                                      "payload is not a decomposable op");
+    return emitDefaultSilenceableFailure(target);
+  }
+  FailureOr<SmallVector<Value>> maybeNewResults =
+      decomposableOp.decomposeOperation(rewriter);
+  if (failed(maybeNewResults))
+    return emitDefaultSilenceableFailure(target);
+  rewriter.replaceOp(decomposableOp, *maybeNewResults);
+  // Values without a defining op (block arguments) cannot be tracked; skip.
+  for (Value val : *maybeNewResults)
+    if (Operation *definition = val.getDefiningOp())
+      results.push_back(definition);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // EliminateLinalgOpAnchoredEmptyTensorsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/AggregatedOpInterface.cpp b/mlir/lib/Interfaces/AggregatedOpInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/AggregatedOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- AggregatedOpInterface.cpp - Aggregated op interface ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the interface in
+// `AggregatedOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/AggregatedOpInterface.h"
+
+using namespace mlir;
+
+#include "mlir/Interfaces/AggregatedOpInterface.cpp.inc"
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_OPTIONAL_SOURCES
+  AggregatedOpInterface.cpp
   CallInterfaces.cpp
   CastInterfaces.cpp
   ControlFlowInterfaces.cpp
@@ -36,6 +37,7 @@
 endfunction(add_mlir_interface_library)
 
 
+add_mlir_interface_library(AggregatedOpInterface)
 add_mlir_interface_library(CallInterfaces)
 add_mlir_interface_library(CastInterfaces)
 add_mlir_interface_library(ControlFlowInterfaces)
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
+// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -199,8 +202,54 @@
   return %0 : tensor<?x?x1x?xf32>
 }
 
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+// The checks below expect: max reduction, exp(x - max), sum reduction, divide.
+// CHECK-LABEL:      func.func @softmax(
+//CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFF800000 : f32
+// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
+// CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
+// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        return %[[D7]] : tensor<2x16x32xf32>
+// CHECK:      }
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
+  // Match the linalg.softmax ops and decompose them with decompose_interface.
   %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
 }