diff --git a/drivers/clk/imx/clk-composite-8m.c b/drivers/clk/imx/clk-composite-8m.c
index 560d74aac80..2cb7d135000 100644
--- a/drivers/clk/imx/clk-composite-8m.c
+++ b/drivers/clk/imx/clk-composite-8m.c
@@ -117,6 +117,41 @@ static const struct clk_ops imx8m_clk_composite_divider_ops = {
 	.set_rate = imx8m_clk_composite_divider_set_rate,
 };
 
+static int imx8m_clk_mux_set_parent(struct clk *clk, struct clk *parent)
+{
+	struct clk_mux *mux = to_clk_mux(clk);
+	int index;
+	u32 val;
+	u32 reg;
+
+	index = clk_mux_fetch_parent_index(clk, parent);
+	if (index < 0) {
+		log_err("Could not fetch index\n");
+		return index;
+	}
+
+	val = clk_mux_index_to_val(mux->table, mux->flags, index);
+
+	reg = readl(mux->reg);
+	reg &= ~(mux->mask << mux->shift);
+	val = val << mux->shift;
+	reg |= val;
+
+	/*
+	 * write twice to make sure non-target interface
+	 * SEL_A/B point the same clk input.
+	 */
+	writel(reg, mux->reg);
+	writel(reg, mux->reg);
+
+	return 0;
+}
+
+const struct clk_ops imx8m_clk_mux_ops = {
+	.get_rate = clk_generic_get_rate,
+	.set_parent = imx8m_clk_mux_set_parent,
+};
+
 struct clk *imx8m_clk_composite_flags(const char *name,
 				      const char * const *parent_names,
 				      int num_parents, void __iomem *reg,
@@ -155,7 +190,7 @@ struct clk *imx8m_clk_composite_flags(const char *name,
 
 	clk = clk_register_composite(NULL, name,
 				     parent_names, num_parents,
-				     &mux->clk, &clk_mux_ops, &div->clk,
+				     &mux->clk, &imx8m_clk_mux_ops, &div->clk,
 				     &imx8m_clk_composite_divider_ops,
 				     &gate->clk, &clk_gate_ops, flags);
 	if (IS_ERR(clk))
